"...git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "a6e310af7e82e484150b9980c17c2a100e601a53"
Unverified commit d0742860, authored by Gabriel Wu, committed by GitHub

[Chore] fix typos (#719)

* chore: fix typos

* chore: fix ruff

* chore: fix clang-format
parent 6545b084
@@ -53,10 +53,7 @@ def get_configs(args, kwargs):
     from tilelang.carver.roller.rasterization import NoRasterization
     import torch
-    if torch.version.hip is not None:
-        arch=CDNA("hip")
-    else:
-        arch = CUDA("cuda")
+    arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     topk = 10
     carve_template = MatmulTemplate(
......
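The same four-line `if/else` collapses into a single conditional expression at several call sites in this commit. As a standalone sketch of the pattern (assuming, per the imports shown elsewhere in this diff, that `CUDA` and `CDNA` are constructed from a target string):

```python
import torch
from tilelang.carver.arch import CUDA, CDNA

# torch.version.hip is None on CUDA builds of PyTorch and holds a version
# string on ROCm builds, so it serves as a cheap backend probe.
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
```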
@@ -187,10 +187,7 @@ def get_configs(args, kwargs):
     from tilelang.carver.roller.rasterization import NoRasterization
     import torch
-    if torch.version.hip is not None:
-        arch=CDNA("hip")
-    else:
-        arch = CUDA("cuda")
+    arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     topk = 10
     carve_template = MatmulTemplate(
......
@@ -252,7 +252,7 @@ def splitk_gemv_vectorized(
     return main
 ```
-With vectorized read, now the kernel finishs in **~0.0084 ms**, which is getting close to cuBLAS performance.
+With vectorized read, now the kernel finishes in **~0.0084 ms**, which is getting close to cuBLAS performance.
 ## `tvm_thread_allreduce` Instead of `atomicAdd`
......
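A quick way to reproduce a figure like ~0.0084 ms and compare against cuBLAS (reached through `torch.matmul`) is CUDA event timing. A minimal sketch, where `gemv_kernel`, `A`, and `x` are illustrative names rather than objects from this diff:

```python
import torch

def time_ms(fn, iters=100):
    # CUDA events measure device-side time, excluding Python launch overhead.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):  # warm-up
        fn()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # elapsed_time returns milliseconds

# e.g. compare: time_ms(lambda: gemv_kernel(A, x)) vs. time_ms(lambda: A @ x)
```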
@@ -4,6 +4,7 @@ from tilelang.carver.arch import CUDA
 from tilelang.carver.arch import CDNA
 from tilelang.layout import make_swizzled_layout
 import torch
+
 N = 64
 C = 256
 H = 512
@@ -95,10 +96,7 @@ def kernel(N,
 def main():
     my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
-    if torch.version.hip is not None:
-        cuda_device=CDNA("hip")
-    else:
-        cuda_device = CUDA("cuda")
+    cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     result = Analyzer.analysis(my_func, cuda_device)
     print(result)
     print(f"Analyzed FLOPs: {result.total_flops}")
......
@@ -49,10 +49,7 @@ def kernel(
 def main():
     my_func = kernel(128, 128, 32, 3, 128, True)
-    if torch.version.hip is not None:
-        cuda_device=CDNA("hip")
-    else:
-        cuda_device = CUDA("cuda")
+    cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     result = Analyzer.analysis(my_func, cuda_device)
     print(f"Analyzed FLOPs: {result.total_flops}")
......
@@ -1373,7 +1373,7 @@ class BitnetForCausalLM(BitnetPreTrainedModel):
                 cache_length + input_ids.shape[1] > max_cache_length):
             attention_mask = attention_mask[:, -max_cache_length:]
-        position_ids = kwargs.get("position_ids", None)
+        position_ids = kwargs.get("position_ids")
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
......
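`dict.get(key)` already defaults to `None`, so the explicit default was redundant. The cumsum trick in the context lines derives token positions from the attention mask, so left padding does not shift positions; a standalone illustration (the tensor values are made up):

```python
import torch

# one left-padded row and one full row
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
# padded slots get a dummy position (they are masked out anyway)
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[1, 1, 0, 1, 2],
#         [0, 1, 2, 3, 4]])
```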
@@ -16,10 +16,7 @@ def ref_program(A, B):
 def get_configs(M, N, K, with_roller=False, topk=20):
     if with_roller:
-        if torch.version.hip is not None:
-            arch=CDNA("hip")
-        else:
-            arch = CUDA("cuda")
+        arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
         carve_template = MatmulTemplate(
             M=M,
             N=N,
......
@@ -230,7 +230,7 @@ Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
         << " and " << B.scope();
   ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn"))
       << "Only support shared.dyn scope for E as copy from smem to rmem are "
-         "delegated to cute implemntation, found "
+         "delegated to cute implementation, found "
       << E.scope();
   ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
   ss << warp_m << ", " << warp_n << ", ";
......
@@ -95,7 +95,7 @@ private:
   Array<String> function_names_;
   /*! \brief whether to emit asserts in the resulting C code */
   bool emit_asserts_;
-  /*! \brief whether to emit forwared function declarations in the resulting C
+  /*! \brief whether to emit forward function declarations in the resulting C
   * code */
   bool emit_fwd_func_decl_;
......
@@ -252,9 +252,9 @@ CodeGenTileLangWebGPU::AddFunction(const PrimFunc &f, bool skip_readonly_decl) {
   os_param_access << "]";
   func_info.launch_param_tags.push_back(os_param_access.str());
-  ICHECK(!info.has_block_index_z)
-      << "blockIdx.z is not supported in WebGPU to accomodate large blockIdx.x";
-  // anotate workgroup
+  ICHECK(!info.has_block_index_z) << "blockIdx.z is not supported in WebGPU to "
+                                     "accommodate large blockIdx.x";
+  // annotate workgroup
   this->stream << "@compute @workgroup_size(" << info.workgroup_size[0] << ", "
                << info.workgroup_size[1] << ", " << info.workgroup_size[2]
                << ")\n";
......
@@ -284,7 +284,7 @@
 #endif
 #ifndef HALF_ENABLE_F16C_INTRINSICS
-/// Enable F16C intruction set intrinsics.
+/// Enable F16C instruction set intrinsics.
 /// Defining this to 1 enables the use of [F16C compiler
 /// intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between
 /// half-precision and single-precision values which may result in improved
@@ -1674,7 +1674,7 @@ template <typename T> T half2float(unsigned int value) {
 /// \tparam R rounding mode to use
 /// \tparam E `true` for round to even, `false` for round away from zero
 /// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never
-/// raise it \tparam T type to convert to (buitlin integer type with at least 16
+/// raise it \tparam T type to convert to (builtin integer type with at least 16
 /// bits precision, excluding any implicit sign bits) \param value
 /// half-precision value to convert \return rounded integer value \exception
 /// FE_INVALID if value is not representable in type \a T \exception FE_INEXACT
@@ -1778,7 +1778,7 @@ inline uint32 divide64(uint32 x, uint32 y, int &s) {
 /// \tparam R `true` to compute signed remainder, `false` for positive remainder
 /// \param x first operand as positive finite half-precision value
 /// \param y second operand as positive finite half-precision value
-/// \param quo adress to store quotient at, `nullptr` if \a Q `false`
+/// \param quo address to store quotient at, `nullptr` if \a Q `false`
 /// \return modulus of \a x / \a y
 template <bool Q, bool R>
 unsigned int mod(unsigned int x, unsigned int y, int *quo = NULL) {
@@ -2435,7 +2435,7 @@ template <typename, typename, std::float_round_style> struct half_caster;
 /// Half-precision floating-point type.
 /// This class implements an IEEE-conformant half-precision floating-point type
 /// with the usual arithmetic operators and conversions. It is implicitly
-/// convertible to single-precision floating-point, which makes artihmetic
+/// convertible to single-precision floating-point, which makes arithmetic
 /// expressions and functions with mixed-type operands to be of the most precise
 /// operand type.
 ///
@@ -2445,9 +2445,9 @@ template <typename, typename, std::float_round_style> struct half_caster;
 /// which means it can be standard-conformantly copied using raw binary copies.
 /// But in this context some more words about the actual size of the type.
 /// Although the half is representing an IEEE 16-bit type, it does not
-/// neccessarily have to be of exactly 16-bits size. But on any reasonable
+/// necessarily have to be of exactly 16-bits size. But on any reasonable
 /// implementation the actual binary representation of this type will most
-/// probably not ivolve any additional "magic" or padding beyond the simple
+/// probably not involve any additional "magic" or padding beyond the simple
 /// binary representation of the underlying 16-bit IEEE number, even if not
 /// strictly guaranteed by the standard. But even then it only has an actual
 /// size of 16 bits if your C++ implementation supports an unsigned integer type
@@ -2801,7 +2801,7 @@ public:
   static HALF_CONSTEXPR_CONST bool traps = true;
 #else
   /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref
-  /// HALF_ERRHANDLING_THROW_INVALID) is acitvated.
+  /// HALF_ERRHANDLING_THROW_INVALID) is activated.
   static HALF_CONSTEXPR_CONST bool traps = false;
 #endif
@@ -5067,7 +5067,7 @@ inline half frexp(half arg, int *exp) {
 /// [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn).
 /// \param arg number to modify
 /// \param exp power of two to multiply with
-/// \return \a arg multplied by 2 raised to \a exp
+/// \return \a arg multiplied by 2 raised to \a exp
 /// \exception FE_INVALID for signaling NaN
 /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
 inline half scalbln(half arg, long exp) {
@@ -5096,7 +5096,7 @@ inline half scalbln(half arg, long exp) {
 /// **See also:** Documentation for
 /// [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). \param
 /// arg number to modify \param exp power of two to multiply with \return \a arg
-/// multplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
+/// multiplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
 /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
 inline half scalbn(half arg, int exp) { return scalbln(arg, exp); }
@@ -5106,7 +5106,7 @@ inline half scalbn(half arg, int exp) { return scalbln(arg, exp); }
 /// **See also:** Documentation for
 /// [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). \param
 /// arg number to modify \param exp power of two to multiply with \return \a arg
-/// multplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
+/// multiplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
 /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
 inline half ldexp(half arg, int exp) { return scalbln(arg, exp); }
@@ -5379,7 +5379,7 @@ inline HALF_CONSTEXPR bool islessequal(half x, half y) {
          !isnan(x) && !isnan(y);
 }
-/// Quiet comarison for less or greater.
+/// Quiet comparison for less or greater.
 /// **See also:** Documentation for
 /// [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater).
 /// \param x first operand
@@ -5503,7 +5503,7 @@ inline int feraiseexcept(int excepts) {
 ///
 /// **See also:** Documentation for
 /// [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag).
-/// \param flagp adress to store flag state at
+/// \param flagp address to store flag state at
 /// \param excepts OR of flags to save
 /// \retval 0 for success
 inline int fegetexceptflag(int *flagp, int excepts) {
@@ -5520,7 +5520,7 @@ inline int fegetexceptflag(int *flagp, int excepts) {
 ///
 /// **See also:** Documentation for
 /// [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag).
-/// \param flagp adress to take flag state from
+/// \param flagp address to take flag state from
 /// \param excepts OR of flags to restore
 /// \retval 0 for success
 inline int fesetexceptflag(const int *flagp, int excepts) {
......
@@ -48,7 +48,7 @@ using int4_t = int4;
   } \
   } while (0)
-// abs function for bfloat_t and half_t since there is no implicit convertion
+// abs function for bfloat_t and half_t since there is no implicit conversion
 // method
 TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
   return half_t(__habs(x.to_half()));
......
@@ -118,7 +118,7 @@ debug_print_buffer_value<signed char>(const char *msg, const char *buf_name,
                threadIdx.z, buf_name, index, var);
 }
-// Specialization for unsiged char type
+// Specialization for unsigned char type
 template <>
 __device__ void
 debug_print_buffer_value<unsigned char>(const char *msg, const char *buf_name,
......
/*!
 * \file atomicadd_vectorize.cc
- * \brief A tool to atomatically vectorize atomic add
+ * \brief A tool to automatically vectorize atomic add
 */
#include "../layout/layout.h"
......
@@ -303,7 +303,7 @@ private:
   bool IsAppropriateSharedMemory(const Var &var) {
     return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var);
   }
-  // Whether do dyanmic analysis.
+  // Whether do dynamic analysis.
   bool is_dynamic_{true};
   // Whether do aggressive merge.
   bool enable_aggressive_merge_{false};
@@ -435,7 +435,7 @@ private:
       const AllocateNode *alloc = shmem_allocs_[buffer];
       auto alignment = align[i];
       // Modern nvidia architecture performs hardware swizzling (hopper
-      // wgmma/tma for exmaple) requires dynamic shared memory address to
+      // wgmma/tma for example) requires dynamic shared memory address to
       // be aligned to 1024 bytes For other devices, we align to 16 bytes
       if (shmem_alignment_map_.find(buffer) !=
           shmem_alignment_map_.end()) {
@@ -943,7 +943,7 @@ private:
   */
  StorageEntry *NewAlloc(const AllocateNode *op, size_t const_nbits) {
    ICHECK(op != nullptr);
-    // Re-use not successful, allocate a new buffer.
+    // Reuse not successful, allocate a new buffer.
    StorageEntry *entry = arena_.make<StorageEntry>();
    entry->allocs.push_back({op->buffer_var.get()});
    entry->const_nbits = const_nbits;
@@ -1046,7 +1046,7 @@ private:
      sym_free_list_.push_back(e);
    }
  }
-  // Wheather enable dyanmic analysis.
+  // Whether enable dynamic analysis.
  bool is_dynamic_{true};
  // Whether enable verbose logging.
......
@@ -140,9 +140,9 @@ public:
 //
 class LinearAccessPatternFinder final : public StmtExprVisitor {
 public:
-  /*! \brief record the touch hist of statment. */
+  /*! \brief record the touch hist of statement. */
   struct StmtEntry {
-    // The statment
+    // The statement
     const Object *stmt;
     // The index in the linear_seq_ to point to end of the nested scope.
     // This is only set to non-zero if stmt is a nested scope.
@@ -150,7 +150,7 @@ public:
     // offset if offset < 0, means this is the end, the begin entry is
     // current_index + offset
     int64_t scope_pair_offset{0};
-    // The buffer variables this statment touched.
+    // The buffer variables this statement touched.
     std::vector<const VarNode *> touched;
   };
   // The scope of each allocation
@@ -675,7 +675,7 @@ private:
           scope.tag != ".workspace" && scope.tag != ".vtcm";
  }
-  // Alllocate entry of node.
+  // Allocate entry of node.
  // Event entry in liveness analysis
  struct EventEntry {
    // variables we generate
@@ -785,10 +785,10 @@ private:
     for (const AllocateNode *op : e->allocs) {
       ICHECK_EQ(op->extents.size(), 1)
           << "Buffer var " << op->buffer_var->name_hint
-          << " was identified as a re-usable allocation, but has "
+          << " was identified as a reusable allocation, but has "
           << op->extents.size() << " physical dimensions. "
           << "Currently, only flat 1-d memory spaces should be "
-             "identified as re-usable "
+             "identified as reusable "
             "allocations.";
       PrimExpr sz = op->extents[0];
       auto nbits = op->dtype.bits() * op->dtype.lanes();
@@ -905,7 +905,7 @@ private:
   void PlanNewScope(const Object *op) {
     if (thread_scope_ != nullptr) {
       ICHECK(thread_scope_ == op);
-      // erase all memory atatched to this scope.
+      // erase all memory attached to this scope.
       for (auto it = const_free_map_.begin(); it != const_free_map_.end();) {
         if (it->second->attach_scope_ == op) {
           it = const_free_map_.erase(it);
@@ -1023,7 +1023,7 @@ private:
   StorageEntry *NewAlloc(const AllocateNode *op, const Object *attach_scope,
                          const StorageScope &scope, size_t const_nbits) {
     ICHECK(op != nullptr);
-    // Re-use not successful, allocate a new buffer.
+    // Reuse not successful, allocate a new buffer.
     auto entry = std::make_unique<StorageEntry>();
     entry->attach_scope_ = attach_scope;
     entry->scope = scope;
@@ -1050,7 +1050,7 @@ private:
     // have its own allocation with size determined at runtime.
     bool is_known_size = (const_nbits != 0);
-    // Currently, only flat memory spaces can be re-used. Packing
+    // Currently, only flat memory spaces can be reused. Packing
     // into N-d space (e.g. 2-d texture memory on GPUs) will require
     // more in-depth algorithms.
     bool is_flat_memory_space = (num_physical_dimensions == 1);
......
@@ -189,7 +189,7 @@ protected:
       }
     }
   }
-  // return the exposed entries, remove unecessary ones.
+  // return the exposed entries, remove unnecessary ones.
   int sync_count = 0;
   // head are before first sync, tail are after last sync
   std::vector<AccessEntry> head, tail;
......
@@ -527,7 +527,7 @@ public:
     // A single var can be binded in multiple lets
     // but they have to bind to the same value.
     // This is used to allow cases when we reuse a single let
-    // expression to cosntruct a nested expr.
+    // expression to construct a nested expr.
     // (let x = 1 in x + 1) * (let x = 1 in x + 1)
     auto it = let_binding_.find(op->var);
     if (it != let_binding_.end()) {
@@ -683,7 +683,7 @@ public:
     return StmtMutator::VisitStmt_(op);
   }
-  // scalarize the statment
+  // scalarize the statement
   Stmt Scalarize(Stmt stmt) {
     Var idx(var_->name_hint + ".s", var_->dtype);
     stmt = Substitute(stmt, {{var_, idx}});
@@ -701,7 +701,7 @@ private:
   PrimExpr var_lanes_;
   // ramp representing the var.
   PrimExpr ramp_;
-  // flag to mark requirment of scalarization.
+  // flag to mark requirement of scalarization.
   bool need_scalarize_{false};
   // Let binding
   std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;
......
@@ -88,6 +88,7 @@ def reshape_test_smem_2d_2_1d(N, M, dtype):
     return main

+
 def run_reshape_smem_2d_2_1d(N, M, dtype):
     program = reshape_test_smem_2d_2_1d(N, M, dtype)
     jit_kernel = tl.compile(program, out_idx=-1)
@@ -98,11 +99,11 @@ def run_reshape_smem_2d_2_1d(N, M, dtype):
     profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

 def test_reshape_smem_2d_2_1d():
     run_reshape_smem_2d_2_1d(1024, 32, "float32")
     run_reshape_smem_2d_2_1d(2048, 64, "float16")

 if __name__ == "__main__":
     tilelang.testing.main()
@@ -203,7 +203,7 @@ class AutoTuner:
             logger.warning(
                 "`supply_prog` will be ignored as this program is under `with set_autotune_inputs` context."
             )
-            supply_prog = lambda _: get_autotune_inputs()  # noqa: E731·
+            supply_prog = lambda _: get_autotune_inputs()  # noqa: E731
         self.profile_args = ProfileArgs(
             supply_type=supply_type,
......
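The `# noqa: E731` suppresses ruff's "do not assign a lambda expression, use a def" rule; the change itself only strips trailing whitespace (rendered as `·` above). For reference, the `def` form that ruff would otherwise suggest is simply:

```python
def supply_prog(_):
    return get_autotune_inputs()
```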