Unverified commit dda45126 authored by Lei Wang, committed by GitHub

[Refactor] Reduce direct dependency on PyTorch due to its limited type support (#1444)



* [Enhancement] Update KernelParam to use tvm.DataType directly and add torch_dtype conversion method

- Changed dtype in KernelParam from torch.dtype to tvm.DataType to support a wider range of data types and prevent information loss during conversions.
- Added a new method, torch_dtype, to convert tvm.DataType back to torch.dtype for tensor creation (a minimal sketch follows this list).
- Updated various adapters to utilize the new torch_dtype method for parameter type conversion during initialization.
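A minimal sketch of the conversion path described above, assuming the string form of tvm.DataType and the usual torch dtype names; this is an illustration of the idea, not the exact library code:

```python
# Sketch only: converting a tvm.DataType kept on KernelParam back to a torch.dtype.
import torch
from tvm import DataType

def torch_dtype(dtype: DataType) -> torch.dtype:
    s = str(dtype)                  # e.g. "float16", "float8_e4m3", "float4_e2m1fn"
    if s.startswith("float4"):
        return torch.int8           # PyTorch has no FP4 dtype; int8 is used as storage
    if s == "float8_e4m3":
        return torch.float8_e4m3fn  # requires torch >= 2.1
    return getattr(torch, s)        # "float16" -> torch.float16, etc.

param_dtype = DataType("float16")
t = torch.empty((4, 4), dtype=torch_dtype(param_dtype))
```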

* [Enhancement] Refactor CUDA type handling and add support for FP4 and FP8 types

- Renamed functions for clarity: GetFP8Type, GetFP6Type, and GetFP4Type are now GetTileLangFP8Type, GetTileLangFP6Type, and GetTileLangFP4Type respectively.
- Enhanced FP4 type handling to support additional lane sizes (2, 4, 8, 16, 32, 64).
- Updated CUDA code generation to include new FP8 and FP4 types, ensuring proper type handling in PrintType and related functions.
- Introduced new structures for FP8 types in cuda_fp8.h to facilitate better memory management and type packing.
- Added methods in KernelParam and tensor utilities to recognize and handle float4 types, improving compatibility with PyTorch.
- Enhanced logging for debugging purposes in various CUDA functions to track type handling and memory operations more effectively.

* lint fix

* Remove unnecessary logging statements from CUDA code generation and delete obsolete matrix multiplication test file.

* [Enhancement] Add support for FP4 and FP8 types in CUDA code generation

- Enhanced PrintVecElemLoad and PrintVecElemStore functions to handle new FP4 types.
- Updated arg_binder to allow float4 to match int8 at runtime, improving compatibility with PyTorch (a host-side storage sketch follows this list).
- Modified loop_vectorize to account for buffer dtype lanes in vectorization calculations.
- Refactored tensor type mapping to support new float4 and float8 types, ensuring correct type handling in tensor operations.
- Added tests for FP4 and FP8 copy operations to validate functionality and integration with existing workflows.
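Because PyTorch has no native FP4 dtype, the host side feeds float4_e2m1fn buffers with int8 storage of the same logical shape, which the relaxed arg_binder check now accepts. A hedged sketch of that calling convention (the `kernel` object is a placeholder for a compiled tilelang kernel):

```python
# Sketch of the host-side convention for FP4 kernel arguments.
import torch

M, N = 1024, 1024
# 4-bit e2m1 codes fit in the range 0..15; int8 is only the storage container.
fp4_storage = torch.randint(0, 16, (M, N), device="cuda", dtype=torch.int8)

# out = kernel(fp4_storage)  # kernel compiled from a program whose A buffer is float4_e2m1fn
```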

---------
Co-authored-by: Zhiwen Mo <zm125@ic.ac.uk>
parent 81b8c1b7
@@ -107,7 +107,7 @@ struct CUDAIEEEMath {
}
};
static std::string GetTileLangFP8Type(DataType type) {
std::stringstream stream;
int32_t lanes = type.lanes();
std::string vec;
@@ -134,13 +134,15 @@ static std::string GetFP8Type(DataType type) {
} else if (type.is_float8_e5m2() || type.is_float8_e5m2fnuz() ||
type.is_float8_e5m2()) {
stream << "fp8_e5" << vec << "_t";
} else if (type.is_float8_e8m0fnu()) {
stream << "fp8_e8" << vec << "_t";
} else {
LOG(FATAL) << "Unsupported FP8 type in CUDA codegen but got " << type;
}
return stream.str();
}
std::string GetTileLangFP6Type(DataType type) {
std::stringstream stream;
int32_t lanes = type.lanes();
std::string vec;
@@ -171,32 +173,37 @@ std::string GetFP6Type(DataType type) {
return stream.str();
}
std::string GetTileLangFP4Type(DataType type) {
std::stringstream stream;
int32_t lanes = type.lanes();
std::string vec;
if (type.is_scalar()) {
vec = "";
} else if (lanes == 2) {
vec = "_2";
} else if (lanes == 4) {
vec = "_4";
} else if (lanes == 8) {
vec = "_8";
} else if (lanes == 16) {
vec = "_16";
} else if (lanes == 32) {
vec = "_32";
} else if (lanes == 64) {
vec = "_64";
} else {
LOG(FATAL) << "Only support scalar and vector types of width (2, 4, 8, 16, "
"32, 64) for FP4";
}
stream << "__nv_fp4";
std::string suffix; std::string suffix;
if (type.code() == DataType::kFloat4_e2m1fn) { if (type.code() == DataType::kFloat4_e2m1fn) {
suffix = "_e2m1"; suffix = "_e2";
} else { } else {
LOG(FATAL) << "Unsupported FP4 type in CUDA codegen"; LOG(FATAL) << "Unsupported FP4 type in CUDA codegen";
} }
stream << vec << suffix;
stream << "fp4" << suffix << vec << "_t";
return stream.str(); return stream.str();
} }
@@ -278,6 +285,9 @@ std::string CodeGenTileLangCUDA::Finish() {
if (enable_fp8_) {
decl_stream << "#include <tl_templates/cuda/cuda_fp8.h>\n";
}
if (enable_fp4_) {
decl_stream << "#include <tl_templates/cuda/cuda_fp4.h>\n";
}
if (need_math_constants_h_) {
decl_stream << "#include <math_constants.h>\n";
@@ -437,18 +447,20 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
return;
} else if (t.is_float8()) {
enable_fp8_ = true;
os << GetTileLangFP8Type(t);
return;
} else if (t.is_float6()) {
enable_fp6_ = true;
if (t.lanes() <= 4) {
os << GetTileLangFP6Type(t);
}
return;
} else if (t.is_float4()) {
enable_fp4_ = true;
if (t.lanes() <= 64) {
os << GetTileLangFP4Type(t);
} else {
fail = true;
}
return;
} else if (t == DataType::Bool()) {
@@ -665,7 +677,9 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
}
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < 256 / t.bits())
<< "i: " << i << " t: " << t << " t.bits(): " << t.bits()
<< " t.lanes(): " << t.lanes();
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char";
if (t.lanes() == 2 || t.lanes() == 3) {
@@ -707,6 +721,22 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
os << "." << access[(i % 8) / 4];
// fp8_e5_4_t or fp8_e5_2_t
os << "." << access[i % 4];
} else if (t.is_float4_e2m1fn()) {
os << vec;
// fp4_e2_64_t
if (t.lanes() >= 64)
os << "." << access[i / 32];
// fp4_e2_32_t
if (t.lanes() >= 32)
os << "." << access[(i % 32) / 16];
// fp4_e2_16_t
if (t.lanes() >= 16)
os << "." << access[(i % 16) / 8];
// fp4_e2_8_t
if (t.lanes() >= 8)
os << "." << access[(i % 8) / 4];
// fp4_e2_4_t or fp4_e2_2_t
os << "." << access[i % 4];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -810,6 +840,22 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
ICHECK(!type_name.empty());
stream << "((" << type_name << "2*)(&(" << vec << "." << access[i / 2]
<< ")))->" << access[i % 2] << " = " << value << ";\n";
} else if (t.is_float4_e2m1fn()) {
stream << vec;
// fp4_e2_64_t
if (t.lanes() >= 64)
stream << "." << access[i / 32];
// fp4_e2_32_t
if (t.lanes() >= 32)
stream << "." << access[(i % 32) / 16];
// fp4_e2_16_t
if (t.lanes() >= 16)
stream << "." << access[(i % 16) / 8];
// fp4_e2_8_t
if (t.lanes() >= 8)
stream << "." << access[(i % 8) / 4];
// fp4_e2_4_t or fp4_e2_2_t
stream << "." << access[i % 4] << " = " << value << ";\n";
} else {
stream << vec << "." << access[i] << " = " << value << ";\n";
}
@@ -1365,7 +1411,7 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
return os.str();
}
std::string index_str = PrintExpr(index);
if ((t.bits() == 4 && !t.is_float4()) || (t.bits() == 1 && t.is_int())) {
// This is a special case, because CodegenCUDA::PrintType()
// returns "int" for bool and for 4-bit integers. In most cases,
// we divide by the number of lanes to determine the index.
@@ -2895,7 +2941,12 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
} else {
bool can_vector_load = false;
arith::PVar<PrimExpr> base;
// For sub-byte types with lanes > 1 in element_dtype, adjust the ramp
// pattern
int ramp_lanes = (element_dtype.lanes() > 1 && element_dtype.bits() < 8)
? value_dtype.lanes() / element_dtype.lanes()
: value_dtype.lanes();
if (arith::ramp(base, 1, ramp_lanes).Match(index)) {
const RampNode *ramp = index.as<RampNode>();
ICHECK(ramp);
can_vector_load = true;
@@ -2907,11 +2958,6 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
// }
}
if (value_dtype.is_float4_e2m1fn() && lanes != 1) {
// A float4_e2m1fn element has 4 bits, which is an incomplete byte.
// So we cannot vector load it.
can_vector_load = false;
}
if (can_vector_load) {
std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
HandleVolatileLoads(ref, op, os);
@@ -2945,6 +2991,69 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
}
}
void CodeGenTileLangCUDA::VisitStmt_(const BufferStoreNode *op) {
ICHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
ICHECK(!op->predicate.defined())
<< "Predicated buffer store is not supported.";
DataType value_dtype = op->value.dtype();
DataType element_dtype = op->buffer->dtype;
PrimExpr index_expr = op->indices[0];
Var buffer_var = op->buffer->data;
if (value_dtype.lanes() == element_dtype.lanes()) {
std::string value = this->PrintExpr(op->value);
std::string ref =
this->GetBufferRef(value_dtype, op->buffer.get(), index_expr);
this->PrintIndent();
stream << ref << " = " << value << ";\n";
} else {
arith::PVar<PrimExpr> base;
// For sub-byte types with lanes > 1 in element_dtype, adjust the ramp
// pattern
int ramp_lanes = (element_dtype.lanes() > 1 && element_dtype.bits() < 8)
? value_dtype.lanes() / element_dtype.lanes()
: value_dtype.lanes();
if (arith::ramp(base, 1, ramp_lanes).Match(index_expr)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer.get(), value_dtype, base.Eval(), value);
} else {
// The assignment below introduces side-effect, and the resulting value
// cannot be reused across multiple expression, thus a new scope is needed
int vec_scope = BeginScope();
// store elements separately
std::string index = SSAGetID(PrintExpr(index_expr), index_expr.dtype());
std::string value = SSAGetID(PrintExpr(op->value), op->value.dtype());
std::string vid = GetVarID(buffer_var.get());
for (int i = 0; i < value_dtype.lanes(); ++i) {
this->PrintIndent();
DataType elem_type = value_dtype.element_of();
if (!HandleTypeMatch(buffer_var.get(), elem_type)) {
stream << "((";
if (buffer_var.get()->dtype.is_handle()) {
auto it = alloc_storage_scope_.find(buffer_var.get());
if (it != alloc_storage_scope_.end()) {
PrintStorageScope(it->second, stream);
}
}
PrintType(elem_type, stream);
stream << "*)" << vid << ')';
} else {
stream << vid;
}
stream << '[';
PrintVecElemLoad(index, index_expr.dtype(), i, stream);
stream << "] = ";
PrintVecElemLoad(value, op->value.dtype(), i, stream);
stream << ";\n";
}
EndScope(vec_scope);
}
}
}
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
...
@@ -57,6 +57,7 @@ public:
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AttrStmtNode *op) final;
void VisitExpr_(const BufferLoadNode *op, std::ostream &os) final;
void VisitStmt_(const BufferStoreNode *op) final;
// Override this as a work around for __grid_constant__ parameter
void AddFunction(const GlobalVar &gvar, const PrimFunc &f);
...
@@ -5,6 +5,7 @@
namespace tl {
// 256-bit load for longlong4
__device__ __forceinline__ longlong4 ld_global_256(const longlong4 *ptr) {
longlong4 ret;
asm volatile("ld.global.v4.s64 {%0, %1, %2, %3}, [%4];"
@@ -13,13 +14,18 @@ __device__ __forceinline__ longlong4 ld_global_256(const longlong4 *ptr) {
return ret;
}
// 256-bit load for ulonglong4
__device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) {
ulonglong4 ret;
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
return ret;
}
// Generic 256-bit load for FP8 types (returns ulonglong4)
template <typename T>
__device__ __forceinline__ ulonglong4 ld_global_256(const T *ptr) {
ulonglong4 ret;
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
@@ -27,6 +33,22 @@ __device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) {
return ret;
}
// 256-bit store for longlong4
__device__ __forceinline__ void st_global_256(longlong4 *ptr, longlong4 &val) {
asm volatile("st.global.v4.s64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
// 256-bit store for ulonglong4 with non-const reference
__device__ __forceinline__ void st_global_256(ulonglong4 *ptr,
ulonglong4 &val) {
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
// 256-bit store for ulonglong4 with const reference
// must be const &val, otherwise the compiler will generate a temporary variable
// and compilation will fail if we have st_global_256(ptr, ld_global_256(ptr))
__device__ __forceinline__ void st_global_256(ulonglong4 *ptr,
@@ -36,35 +58,22 @@ __device__ __forceinline__ void st_global_256(ulonglong4 *ptr,
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e4_32_t *ptr) {
ulonglong4 ret;
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr,
fp8_e4_32_t &val8) {
ulonglong4 &val = *((ulonglong4 *)&val8);
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
// Generic 256-bit store for FP8 types
template <typename T>
__device__ __forceinline__ void st_global_256(T *ptr, const ulonglong4 &val) {
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e5_32_t *ptr) {
ulonglong4 ret;
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_global_256(fp8_e5_32_t *ptr,
fp8_e5_32_t &val8) {
ulonglong4 &val = *((ulonglong4 *)&val8);
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
// Generic 256-bit store for FP8 types with non-const reference
template <typename T>
__device__ __forceinline__ void st_global_256(T *ptr, T &val) {
ulonglong4 &val_u64 = *((ulonglong4 *)&val);
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val_u64.x), "l"(val_u64.y), "l"(val_u64.z),
"l"(val_u64.w));
}
__device__ __forceinline__ unsigned long long
...
#pragma once
#include "common.h"
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
#include <cuda_fp4.h>
// Wrapper for __nv_fp4_e2m1 with implicit conversions
struct fp4_e2_t {
__nv_fp4_storage_t __x;
TL_DEVICE fp4_e2_t() = default;
// Constructor from __nv_fp4_e2m1
TL_DEVICE fp4_e2_t(__nv_fp4_e2m1 x) : __x(x.__x) {}
// Constructor from storage type
TL_DEVICE fp4_e2_t(__nv_fp4_storage_t x) : __x(x) {}
// Constructor from float
TL_DEVICE explicit fp4_e2_t(float x) {
__nv_fp4_e2m1 tmp(x);
__x = tmp.__x;
}
// Conversion to __nv_fp4_e2m1
TL_DEVICE operator __nv_fp4_e2m1() const {
__nv_fp4_e2m1 tmp;
tmp.__x = __x;
return tmp;
}
// Conversion to float
TL_DEVICE operator float() const {
__nv_fp4_e2m1 tmp;
tmp.__x = __x;
return float(tmp);
}
// Implicit conversion to half_t (cutlass::half_t)
TL_DEVICE operator half_t() const { return half_t(float(*this)); }
// Implicit conversion to __half
TL_DEVICE operator __half() const { return __half(float(*this)); }
};
using fp4_e2x2_t = __nv_fp4x2_e2m1;
using fp4_e2x4_t = __nv_fp4x4_e2m1;
struct fp4_e2x8_t {
fp4_e2_t data[8];
};
struct fp4_e2x16_t {
fp4_e2_t data[16];
};
struct __CUDA_ALIGN__(1) fp4_e2_2_t {
fp4_e2_t x;
fp4_e2_t y;
};
struct __CUDA_ALIGN__(2) fp4_e2_4_t {
fp4_e2_t x;
fp4_e2_t y;
fp4_e2_t z;
fp4_e2_t w;
};
struct __CUDA_ALIGN__(4) fp4_e2_8_t {
fp4_e2_4_t x;
fp4_e2_4_t y;
};
struct __CUDA_ALIGN__(8) fp4_e2_16_t {
fp4_e2_8_t x;
fp4_e2_8_t y;
};
struct __CUDA_ALIGN__(16) fp4_e2_32_t {
fp4_e2_16_t x;
fp4_e2_16_t y;
TL_DEVICE fp4_e2_32_t &operator=(const ulonglong4 &rhs) {
x.x = *(fp4_e2_8_t *)&rhs.x;
x.y = *(fp4_e2_8_t *)&rhs.y;
y.x = *(fp4_e2_8_t *)&rhs.z;
y.y = *(fp4_e2_8_t *)&rhs.w;
return *this;
}
};
struct __CUDA_ALIGN__(32) fp4_e2_64_t {
fp4_e2_32_t x;
fp4_e2_32_t y;
};
// Pack two fp4_e2_t values.
TL_DEVICE fp4_e2_2_t make_fp4_e2_2_t(fp4_e2_t x, fp4_e2_t y) {
fp4_e2_2_t result;
result.x = x;
result.y = y;
return result;
}
// Pack four fp4_e2_t values.
TL_DEVICE fp4_e2_4_t make_fp4_e2_4_t(fp4_e2_t x0, fp4_e2_t x1, fp4_e2_t x2,
fp4_e2_t x3) {
fp4_e2_4_t result;
result.x = x0;
result.y = x1;
result.z = x2;
result.w = x3;
return result;
}
// Pack eight fp4_e2_t values.
TL_DEVICE fp4_e2_8_t make_fp4_e2_8_t(fp4_e2_t x0, fp4_e2_t x1, fp4_e2_t x2,
fp4_e2_t x3, fp4_e2_t x4, fp4_e2_t x5,
fp4_e2_t x6, fp4_e2_t x7) {
fp4_e2_8_t result;
result.x = make_fp4_e2_4_t(x0, x1, x2, x3);
result.y = make_fp4_e2_4_t(x4, x5, x6, x7);
return result;
}
// Pack sixteen fp4_e2_t values.
TL_DEVICE fp4_e2_16_t make_fp4_e2_16_t(fp4_e2_t x0, fp4_e2_t x1, fp4_e2_t x2,
fp4_e2_t x3, fp4_e2_t x4, fp4_e2_t x5,
fp4_e2_t x6, fp4_e2_t x7, fp4_e2_t y0,
fp4_e2_t y1, fp4_e2_t y2, fp4_e2_t y3,
fp4_e2_t y4, fp4_e2_t y5, fp4_e2_t y6,
fp4_e2_t y7) {
fp4_e2_16_t result;
result.x = make_fp4_e2_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
result.y = make_fp4_e2_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
return result;
}
// Pack thirty-two fp4_e2_t values.
TL_DEVICE fp4_e2_32_t make_fp4_e2_32_t(
fp4_e2_t x0, fp4_e2_t x1, fp4_e2_t x2, fp4_e2_t x3, fp4_e2_t x4,
fp4_e2_t x5, fp4_e2_t x6, fp4_e2_t x7, fp4_e2_t x8, fp4_e2_t x9,
fp4_e2_t x10, fp4_e2_t x11, fp4_e2_t x12, fp4_e2_t x13, fp4_e2_t x14,
fp4_e2_t x15, fp4_e2_t y0, fp4_e2_t y1, fp4_e2_t y2, fp4_e2_t y3,
fp4_e2_t y4, fp4_e2_t y5, fp4_e2_t y6, fp4_e2_t y7, fp4_e2_t y8,
fp4_e2_t y9, fp4_e2_t y10, fp4_e2_t y11, fp4_e2_t y12, fp4_e2_t y13,
fp4_e2_t y14, fp4_e2_t y15) {
fp4_e2_32_t result;
result.x = make_fp4_e2_16_t(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11,
x12, x13, x14, x15);
result.y = make_fp4_e2_16_t(y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11,
y12, y13, y14, y15);
return result;
}
#endif
@@ -6,6 +6,7 @@
using fp8_e4_t = tl::float_e4m3_t;
using fp8_e5_t = tl::float_e5m2_t;
using fp8_e8_t = __nv_fp8_e8m0;
struct __CUDA_ALIGN__(2) fp8_e4_2_t {
fp8_e4_t x;
@@ -77,6 +78,41 @@ struct __CUDA_ALIGN__(32) fp8_e5_32_t {
}
};
struct __CUDA_ALIGN__(2) fp8_e8_2_t {
fp8_e8_t x;
fp8_e8_t y;
};
struct __CUDA_ALIGN__(4) fp8_e8_4_t {
fp8_e8_t x;
fp8_e8_t y;
fp8_e8_t z;
fp8_e8_t w;
};
struct __CUDA_ALIGN__(8) fp8_e8_8_t {
fp8_e8_4_t x;
fp8_e8_4_t y;
};
struct __CUDA_ALIGN__(16) fp8_e8_16_t {
fp8_e8_8_t x;
fp8_e8_8_t y;
};
struct __CUDA_ALIGN__(32) fp8_e8_32_t {
fp8_e8_16_t x;
fp8_e8_16_t y;
TL_DEVICE fp8_e8_32_t &operator=(const ulonglong4 &rhs) {
x.x = *(fp8_e8_8_t *)&rhs.x;
x.y = *(fp8_e8_8_t *)&rhs.y;
y.x = *(fp8_e8_8_t *)&rhs.z;
y.y = *(fp8_e8_8_t *)&rhs.w;
return *this;
}
};
// Pack two fp8_e4_t values.
TL_DEVICE fp8_e4_2_t make_fp8_e4_2_t(fp8_e4_t x, fp8_e4_t y) {
fp8_e4_2_t result;
@@ -195,6 +231,65 @@ TL_DEVICE fp8_e5_32_t make_fp8_e5_32_t(
return result;
}
// Pack two fp8_e8_t values.
TL_DEVICE fp8_e8_2_t make_fp8_e8_2_t(fp8_e8_t x, fp8_e8_t y) {
fp8_e8_2_t result;
result.x = x;
result.y = y;
return result;
}
// Pack four fp8_e8_t values.
TL_DEVICE fp8_e8_4_t make_fp8_e8_4_t(fp8_e8_t x0, fp8_e8_t x1, fp8_e8_t x2,
fp8_e8_t x3) {
fp8_e8_4_t result;
result.x = x0;
result.y = x1;
result.z = x2;
result.w = x3;
return result;
}
// Pack eight fp8_e8_t values.
TL_DEVICE fp8_e8_8_t make_fp8_e8_8_t(fp8_e8_t x0, fp8_e8_t x1, fp8_e8_t x2,
fp8_e8_t x3, fp8_e8_t x4, fp8_e8_t x5,
fp8_e8_t x6, fp8_e8_t x7) {
fp8_e8_8_t result;
result.x = make_fp8_e8_4_t(x0, x1, x2, x3);
result.y = make_fp8_e8_4_t(x4, x5, x6, x7);
return result;
}
// Pack sixteen fp8_e8_t values.
TL_DEVICE fp8_e8_16_t make_fp8_e8_16_t(fp8_e8_t x0, fp8_e8_t x1, fp8_e8_t x2,
fp8_e8_t x3, fp8_e8_t x4, fp8_e8_t x5,
fp8_e8_t x6, fp8_e8_t x7, fp8_e8_t y0,
fp8_e8_t y1, fp8_e8_t y2, fp8_e8_t y3,
fp8_e8_t y4, fp8_e8_t y5, fp8_e8_t y6,
fp8_e8_t y7) {
fp8_e8_16_t result;
result.x = make_fp8_e8_8_t(x0, x1, x2, x3, x4, x5, x6, x7);
result.y = make_fp8_e8_8_t(y0, y1, y2, y3, y4, y5, y6, y7);
return result;
}
// Pack thirty-two fp8_e8_t values.
TL_DEVICE fp8_e8_32_t make_fp8_e8_32_t(
fp8_e8_t x0, fp8_e8_t x1, fp8_e8_t x2, fp8_e8_t x3, fp8_e8_t x4,
fp8_e8_t x5, fp8_e8_t x6, fp8_e8_t x7, fp8_e8_t x8, fp8_e8_t x9,
fp8_e8_t x10, fp8_e8_t x11, fp8_e8_t x12, fp8_e8_t x13, fp8_e8_t x14,
fp8_e8_t x15, fp8_e8_t y0, fp8_e8_t y1, fp8_e8_t y2, fp8_e8_t y3,
fp8_e8_t y4, fp8_e8_t y5, fp8_e8_t y6, fp8_e8_t y7, fp8_e8_t y8,
fp8_e8_t y9, fp8_e8_t y10, fp8_e8_t y11, fp8_e8_t y12, fp8_e8_t y13,
fp8_e8_t y14, fp8_e8_t y15) {
fp8_e8_32_t result;
result.x = make_fp8_e8_16_t(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11,
x12, x13, x14, x15);
result.y = make_fp8_e8_16_t(y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11,
y12, y13, y14, y15);
return result;
}
// e4m3x2 -> float2
TL_DEVICE float2
__tl_cvt_fp8x2_to_float2(const __nv_fp8x2_storage_t x,
...
@@ -443,9 +443,21 @@ void ArgBinder::BindDLTensor(const Buffer &buffer, const PrimExpr &device_type,
PrimExpr bit1_ok = (v_type_bits == bits1 && lanes_ok);
cond = cond || int8_ok || uint8_ok || kdlbool8_ok || kdlbool1_ok || bit1_ok;
}
// Allow float4 to match int8 at runtime (PyTorch uses int8 as storage for
// FP4).
if (buffer->dtype.is_float4()) {
PrimExpr code_int = IntImm(DataType::UInt(8), DataType::kInt);
PrimExpr bits8 = IntImm(DataType::UInt(8), 8);
// For FP4, we pack 2 elements per byte, but we still use same lanes at
// storage level Accept int8 with same lanes as the fp4 type
PrimExpr fp4_lanes_ok = (v_type_lanes == expect_lanes);
PrimExpr int8_ok =
(v_type_code == code_int && v_type_bits == bits8 && fp4_lanes_ok);
cond = cond || int8_ok;
}
if (!(buffer->dtype == DataType::Int(1) ||
buffer->dtype == DataType::Int(4) ||
buffer->dtype == DataType::UInt(4) || buffer->dtype.is_float4())) {
// Build FFI packed call to __tvm_error_dtype_mismatch when mismatch occurs.
// Only issue the call when handle is non-NULL and cond is false.
ffi::Array<PrimExpr> packed_args;
...
@@ -192,7 +192,8 @@ private:
vector_size_, analyzer_)) {
// If not, tight vectorize bound with buffer dtype constraint
vector_size_ = arith::ZeroAwareGCD(
vector_size_, vector_load_bits_max_ /
(buffer->dtype.bits() * buffer->dtype.lanes()));
}
// 4. Try to vectorize buffer load
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
...
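The loop_vectorize change above divides the maximum vectorized-load width by bits() * lanes() of the buffer element type, so packed sub-byte elements such as fp4x2 no longer over-vectorize. A quick arithmetic sketch (the 128-bit width is an assumption used for illustration, not a value read from the code):

```python
# Arithmetic sketch of the tightened vectorization bound in loop_vectorize.
vector_load_bits_max = 128  # assumed maximum width of one vectorized access

def max_vector_size(dtype_bits: int, dtype_lanes: int) -> int:
    # old bound: vector_load_bits_max // dtype_bits
    # new bound: also divide by the element type's own lane count
    return vector_load_bits_max // (dtype_bits * dtype_lanes)

print(max_vector_size(16, 1))  # float16 -> 8 elements per access
print(max_vector_size(8, 1))   # fp8     -> 16
print(max_vector_size(4, 2))   # fp4x2   -> 16 (the old bound gave 32)
```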
from tilelang import tvm as tvm
import tilelang.testing
from tilelang.cache import cached
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """
    Defines a matrix multiplication primitive function using tilelang.

    This function constructs a tilelang primitive function for matrix multiplication,
    optimized for execution on hardware accelerators. It utilizes shared memory and
    fragment memory for performance.

    Args:
        M (int): Number of rows in matrix A and C.
        N (int): Number of columns in matrix B and C.
        K (int): Number of columns in matrix A and rows in matrix B.
        block_M (int): Block size for M dimension in shared memory and fragment.
        block_N (int): Block size for N dimension in shared memory and fragment.
        block_K (int): Block size for K dimension in shared memory.
        dtype (str, optional): Data type for input matrices A and B, and output C. Defaults to "float16".
        accum_dtype (str, optional): Accumulation data type for internal computations. Defaults to "float".

    Returns:
        T.PrimFunc: A tilelang primitive function representing the matrix multiplication.
    """

    @T.prim_func
    def main(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_cache_matmul():
    """
    Demonstrates the usage of the cached matrix multiplication kernel.

    This function defines a reference PyTorch matrix multiplication,
    creates a cached kernel from the tilelang matmul function,
    runs the kernel with random input tensors, compares the output with the reference,
    and prints the CUDA kernel source code.
    """

    def ref_program(A, B):
        """
        Reference PyTorch matrix multiplication for comparison.
        """
        import torch
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.half)  # Assuming dtype="float16" in matmul
        return C

    func = matmul(1024, 1024, 1024, 128, 128, 32)
    kernel = cached(func, [2], execution_backend="cython")

    import torch
    a = torch.randn(1024, 1024).cuda().half()
    b = torch.randn(1024, 1024).cuda().half()

    c = kernel(a, b)
    print("\nOutput from Cached Kernel:")
    print(c)

    ref_c = ref_program(a, b)
    print("\nReference PyTorch Output:")
    print(ref_c)

    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
    print("\nOutputs are close (within tolerance).")

    # Get CUDA Source
    print("\nCUDA Kernel Source:")
    print(kernel.get_kernel_source())


def test_cache_matmul_f16f16f16_nn():
    """
    Test function for cached matrix multiplication (float16 inputs, float16 output, no transpose).
    """
    run_cache_matmul()


if __name__ == "__main__":
    tilelang.testing.main()
@@ -3,31 +3,37 @@ import tilelang.language as T
import torch
import tilelang.testing

print(torch.__version__)


# add decorator @tilelang.jit if you want to return a torch function
# @tilelang.jit
def tilelang_copy(M, N, block_M, block_N, src_dtype="float16", dst_dtype="float16"):

    @T.prim_func
    def main(
            A: T.Tensor((M, N), src_dtype),
            B: T.Tensor((M, N), dst_dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            T.copy(
                A[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N],
                B[by * block_M : (by + 1) * block_M, bx * block_N : (bx + 1) * block_N],
            )

    return main


def run_tilelang_copy(M=1024, N=1024, block_M=128, block_N=128, dtype="float16"):
    program = tilelang_copy(M, N, block_M, block_N, src_dtype=dtype, dst_dtype=dtype)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
        target="cuda",
        pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True},
    )
    source = kernel.get_kernel_source()
    print(source)
    a = torch.randn(M, N, device="cuda", dtype=getattr(torch, dtype))
    b = kernel(a)
    torch.testing.assert_close(b, a, rtol=1e-2, atol=1e-2)
@@ -137,5 +143,46 @@ def test_tilelang_copy_buffer_load_with_parallel():
    run_tilelang_copy_buffer_load_with_parallel(M=1024, N=1024, block_M=128, block_N=128)
def run_tilelang_copy_fp8_e8m0(M=1024, N=1024, block_M=128, block_N=128, src_dtype="float8_e8m0fnu", dst_dtype="float8_e8m0fnu"):
    program = tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
    )
    source = kernel.get_kernel_source()
    assert "fp8_e8_t" in source
    dummy_input = torch.randint(0, 100, (M, N), device="cuda", dtype=torch.int8).view(torch.float8_e8m0fnu)
    output = kernel(dummy_input)
    assert output is not None


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(10, 0)
def test_tilelang_copy_fp8_e8m0():
    run_tilelang_copy_fp8_e8m0(src_dtype="float8_e8m0fnu", dst_dtype="float8_e8m0fnu")


def run_tilelang_copy_fp4(M=1024, N=1024, block_M=128, block_N=128, src_dtype="float4_e2m1fn", dst_dtype="float4_e2m1fn"):
    program = tilelang_copy(M, N, block_M, block_N, src_dtype=src_dtype, dst_dtype=dst_dtype)
    kernel = tilelang.compile(
        program,
        out_idx=[1],
    )
    source = kernel.get_kernel_source()
    assert "fp4_e2_t" in source
    # For FP4, use same shape as kernel expects, since int8 is used as storage type
    dummy_input = torch.randint(0, 100, (M, N), device="cuda", dtype=torch.int8)
    output = kernel(dummy_input)
    assert output is not None


@tilelang.testing.requires_cuda
@tilelang.testing.requires_cuda_compute_version_ge(10, 0)
def test_tilelang_copy_fp4():
    run_tilelang_copy_fp4(src_dtype="float4_e2m1fn", dst_dtype="float4_e2m1fn")
    run_tilelang_copy_fp4(src_dtype="float4_e2m1fn", dst_dtype="float16")
    run_tilelang_copy_fp4(src_dtype="float4_e2m1fn", dst_dtype="bfloat16")
if __name__ == "__main__": if __name__ == "__main__":
tilelang.testing.main() tilelang.testing.main()
@@ -16,7 +16,11 @@ class KernelParam:
    Used to describe tensor or scalar parameters in TVM/PyTorch interop.
    """

    # Use tvm.DataType (buffer.dtype) directly instead of torch.dtype to support more data types
    # tvm.DataType can represent a much wider range of types than PyTorch's dtype system,
    # including specialized types like float8_e4m3, float8_e5m2, custom quantized types, etc.
    # This avoids information loss when converting from TVM buffer types
    dtype: tvm.DataType  # Data type from buffer.dtype (supports all TVM types)
    shape: list[int | Var]  # List of dimensions, can be integers or TVM variables

    @classmethod
@@ -28,12 +32,14 @@ class KernelParam:
            buffer: TVM Buffer object containing dtype and shape information

        Returns:
            KernelParam instance with dtype directly from buffer and shape

        Raises:
            ValueError: If dimension type is not supported (not IntImm or Var)
        """
        # Use buffer.dtype directly (tvm.DataType) to preserve all type information
        # buffer.dtype is already a tvm.DataType object, no conversion needed
        dtype = buffer.dtype
        shape = []
        for s in buffer.shape:
            if isinstance(s, IntImm):
@@ -56,7 +62,9 @@ class KernelParam:
        Returns:
            KernelParam instance representing a scalar (empty shape)
        """
        # Use var.dtype directly (tvm.DataType) to preserve all type information
        # var.dtype is already a tvm.DataType object, no conversion needed
        dtype = var.dtype
        return cls(dtype, [])

    def is_scalar(self) -> bool:
@@ -92,6 +100,18 @@ class KernelParam:
            dtype_str = dtype_str[6:]
        return dtype_str.startswith("float8")

    def is_float4(self) -> bool:
        """
        Checks if the parameter represents a float4 type.

        Returns:
            bool: True if parameter is a float4 type, False otherwise
        """
        dtype_str = str(self.dtype)
        if dtype_str.startswith("torch."):
            dtype_str = dtype_str[6:]
        return dtype_str.startswith("float4")

    def is_boolean(self) -> bool:
        """
        Checks if the parameter represents a boolean type.
@@ -104,6 +124,22 @@ class KernelParam:
            dtype_str = dtype_str[6:]
        return dtype_str.startswith("bool")

    def torch_dtype(self) -> torch.dtype:
        """
        Converts the TVM DataType to PyTorch dtype.

        This method is used when creating PyTorch tensors from KernelParam,
        as PyTorch's tensor creation functions require torch.dtype.

        Returns:
            torch.dtype: Corresponding PyTorch dtype

        Example:
            >>> param = KernelParam.from_buffer(buffer)
            >>> tensor = torch.empty(shape, dtype=param.torch_dtype())
        """
        return map_torch_type(str(self.dtype))


@dataclass
class CompiledArtifact:
...
@@ -76,7 +76,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
        self.ir_module = func_or_mod

        # Cache parameter information during initialization
        # Convert tvm.DataType to torch.dtype for tensor creation
        self.param_dtypes = [param.torch_dtype() for param in params]
        self.param_shapes = []
        for param in params:
            native_shape = []
@@ -139,7 +140,8 @@ class CtypesKernelAdapter(BaseKernelAdapter):
        adapter.ir_module = func_or_mod

        # Cache parameter information during initialization
        # Convert tvm.DataType to torch.dtype for tensor creation
        adapter.param_dtypes = [param.torch_dtype() for param in params]
        adapter.param_shapes = []
        for param in params:
            native_shape = []
...
@@ -32,7 +32,8 @@ cdef class CythonKernelWrapper:
        self.params = params
        self.lib = lib
        # Convert TVM types to native Python types during initialization
        # Convert tvm.DataType to torch.dtype for tensor creation
        self.param_dtypes = [param.torch_dtype() for param in params]
        # Convert TVM shape arrays to native Python lists
        self.param_shapes = []
        self.get_current_device = torch.cuda.current_device
...
@@ -52,7 +52,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
        self.ir_module = func_or_mod

        # Cache parameter information during initialization
        # Convert tvm.DataType to torch.dtype for tensor creation
        self.param_dtypes = [param.torch_dtype() for param in params]
        self.param_shapes = []
        for param in params:
            native_shape = []
@@ -118,7 +119,8 @@ class NVRTCKernelAdapter(BaseKernelAdapter):
        adapter.ir_module = func_or_mod

        # Cache parameter information during initialization
        # Convert tvm.DataType to torch.dtype for tensor creation
        adapter.param_dtypes = [param.torch_dtype() for param in params]
        adapter.param_shapes = []
        for param in params:
            native_shape = []
...
@@ -135,7 +135,8 @@ class TVMFFIKernelAdapter(BaseKernelAdapter):
        current_device_functor = self.get_current_device_functor()

        # Convert TVM types to native Python types during initialization
        # Convert tvm.DataType to torch.dtype for tensor creation
        param_dtypes = [param.torch_dtype() for param in self.params]

        # Convert TVM shape arrays to native Python lists
        param_shapes = []
...
@@ -32,7 +32,11 @@ class TensorSupplyType(Enum):
    Auto = 7


def map_torch_type(intype) -> torch.dtype:
    # Convert to string if needed
    if not isinstance(intype, str):
        intype = str(intype)
    if intype == "float8_e4m3":
        assert hasattr(torch, "float8_e4m3fn"), "torch.float8_e4m3fn is not supported in this version of torchPlease upgrade torch >= 2.1.0"
        return torch.float8_e4m3fn
@@ -44,6 +48,19 @@ def map_torch_type(intype: str) -> torch.dtype:
            "torch.float8_e4m3fnuz is not supported in this version of torchPlease upgrade torch >= 2.2.0"
        )
        return torch.float8_e4m3fnuz
    elif intype == "float8_e8m0fnu":
        assert hasattr(torch, "float8_e8m0fnu"), (
            "torch.float8_e8m0fnu is not supported in this version of torchPlease upgrade torch >= 2.8.0"
        )
        return torch.float8_e8m0fnu
    elif intype == "float4_e2m1fnx2":
        assert hasattr(torch, "float4_e2m1fnx2"), (
            "torch.float4_e2m1fnx2 is not supported in this version of torchPlease upgrade torch >= 2.8.0"
        )
        return torch.float4_e2m1fnx2
    elif "float4" in intype:
        # PyTorch doesn't support float4, use int8 as storage type
        return torch.int8
    else:
        return getattr(torch, intype)
@@ -53,7 +70,8 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
    from .device import get_current_device

    def get_tensor(param: KernelParam) -> torch.Tensor:
        # Convert tvm.DataType to torch.dtype for tensor creation
        dtype: torch.dtype = param.torch_dtype()
        device = get_current_device()

        if hasattr(param, "shape") and not param.shape:
@@ -74,11 +92,14 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
        if supply_type == TensorSupplyType.Auto:
            is_unsigned = param.is_unsigned()
            is_float8 = param.is_float8()
            is_float4 = param.is_float4()
            is_boolean = param.is_boolean()
            if is_unsigned:
                return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
            elif is_float8:
                return torch.randint(low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
            elif is_float4:
                return torch.randint(low=0, high=16, size=shape, device=device, dtype=dtype)
            elif is_boolean:
                return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype)
            elif dtype in {torch.float16, torch.float32, torch.bfloat16}:
@@ -95,11 +116,14 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
        if supply_type == TensorSupplyType.Integer:
            is_unsigned = param.is_unsigned()
            is_float8 = param.is_float8()
            is_float4 = param.is_float4()
            is_boolean = param.is_boolean()
            if is_unsigned:
                return torch.randint(low=0, high=3, size=shape, device=device, dtype=dtype)
            elif is_float8:
                return torch.randint(low=-128, high=128, size=shape, device=device, dtype=torch.int8).to(dtype)
            elif is_float4:
                return torch.randint(low=0, high=16, size=shape, device=device, dtype=dtype)
            elif is_boolean:
                return torch.randint(low=0, high=2, size=shape, device=device, dtype=dtype)
            else:
...