Unverified commit f58bcd43, authored by Zhiwen Mo, committed by GitHub

[SM100] Add sm100 GEMM layouts and tcgen05 support (#887)

* update sm100 related utcmma, tmem, ld/st256 in src
* update sm100 related utcmma, tmem, ld/st256 in tilelang
* Remove deprecated GEMM examples and related README documentation for SM100 architecture support
* Update GEMM implementation to replace UTCMMA with TCGEN5MMA across relevant files
* Remove gemm_umma.py example and update README to reflect TCGEN5MMA terminology changes
* Update README.md for gemm_sm100 example by removing outdated API sections and streamlining documentation
* Update README and source files to reflect TCGEN5.MMA terminology changes
* Refactor CUDA GEMM header for improved readability
parent c382dcbc
@@ -120,9 +120,12 @@ static std::string GetFP8Type(DataType type) {
vec = "_8";
} else if (lanes == 16) {
vec = "_16";
} else if (lanes == 32) {
vec = "_32";
} else {
LOG(FATAL)
<< "Only support scalar and vector types of width (2, 4, 8, 16, 32) "
"for FP8";
}
if (type.is_float8_e4m3fn() || type.is_float8_e4m3fnuz() ||
type.is_float8_e4m3()) {
@@ -354,6 +357,10 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
//
ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "uint" << lanes / 2;
} else if (lanes <= 16) {
ICHECK_EQ(lanes % 4, 0) << "only support (mod 4 = 0) lanes for half "
"type of more than 8 lanes";
os << "ulonglong" << lanes / 4;
} else {
fail = true;
}
@@ -398,6 +405,10 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
} else if (lanes <= 8) {
ICHECK_EQ(lanes % 2, 0) << "only support even lane for half type";
os << "uint" << lanes / 2;
} else if (lanes <= 16) {
ICHECK_EQ(lanes % 4, 0) << "only support (mod 4 = 0) lanes for half type "
"of more than 8 lanes";
os << "ulonglong" << lanes / 4;
} else {
fail = true;
}
@@ -494,6 +505,10 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
enable_int8_ = true;
os << "int4";
return;
} else if (t.lanes() == 32) {
enable_int8_ = true;
os << "longlong4";
return;
} else if (!t.is_uint() && t.is_scalar()) {
os << "signed char";
break;
@@ -561,8 +576,13 @@ void CodeGenTileLangCUDA::PrintType(DataType t, std::ostream &os) { // NOLINT(*)
os << "longlong3";
} else if (t.lanes() == 4) {
os << "longlong4";
} else {
fail = true;
} }
if (!fail) {
return;
}
break;
}
default:
fail = true;
@@ -624,23 +644,48 @@ void CodeGenTileLangCUDA::PrintVecElemLoad(const std::string &vec, DataType t,
}
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < 256 / t.bits());
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
std::string type_name = t.is_int() ? "char" : "unsigned char";
if (t.lanes() == 2 || t.lanes() == 3) {
os << vec << "." << access[i % t.lanes()];
} else if (t.lanes() <= 16) {
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
os << "((" << type_name << ")(" << ac << " >> " << i % 4 * 8 << "))";
} else {
ICHECK(t.lanes() == 32);
std::string ac = vec + "." + access[i / 8];
os << "((" << type_name << ")(" << ac << " >> " << i % 8 * 8 << "))";
}
} else if (t.is_float16()) {
if (t.lanes() <= 8) {
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else {
os << "(((half2*)(&(" << vec << "." << access[i / 4] << "))) + "
<< (i / 2 % 2) << ")->" << access[i % 2];
}
} else if (t.is_bfloat16()) {
if (t.lanes() <= 8) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else {
os << "(((nv_bfloat162*)(&(" << vec << "." << access[i / 4] << "))) + "
<< (i / 2 % 2) << ")->" << access[i % 2];
}
} else if (t.is_float8()) {
os << vec;
// fp8_e5_32_t
if (t.lanes() >= 32)
os << "." << access[i / 16];
// fp8_e5_16_t
if (t.lanes() >= 16)
os << "." << access[(i % 16) / 8];
// fp8_e5_8_t
if (t.lanes() >= 8)
os << "." << access[(i % 8) / 4];
// fp8_e5_4_t or fp8_e5_2_t
os << "." << access[i % 4];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -670,14 +715,12 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
int i, const std::string &value) {
this->PrintIndent();
static const char access[] = {'x', 'y', 'z', 'w'};
ICHECK(i >= 0 && i < 256 / t.bits());
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (t.lanes() == 2 || t.lanes() == 3) {
stream << vec << '.' << access[i % t.lanes()] << "="
<< "(" << value << ");\n";
} else if (t.lanes() <= 16) {
std::string ac = t.lanes() == 4 ? vec : (vec + "." + access[i / 4]);
stream << ac << "=";
// Do not read the first undef lane.
@@ -685,13 +728,47 @@ void CodeGenTileLangCUDA::PrintVecElemStore(const std::string &vec, DataType t,
stream << ac << " & ~(0x000000ff << " << i % 4 * 8 << ") |";
}
stream << "(" << value << " << " << i % 4 * 8 << ");\n";
} else {
ICHECK(t.lanes() == 32);
std::string ac = vec + "." + access[i / 8];
stream << ac << "=";
// Do not read the first undef lane.
if (i != 0) {
stream << ac << " & ~(0x000000ff << " << i % 8 * 8 << ") |";
}
stream << "(" << value << " << " << i % 8 * 8 << ");\n";
}
} else if (t.is_float16()) {
if (t.lanes() <= 8) {
stream << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
} else {
stream << "(((half2*)(&(" << vec << "." << access[i / 4] << "))) + "
<< (i / 2 % 2) << ")->" << access[i % 2] << " = " << value
<< ";\n";
}
} else if (t.is_bfloat16()) {
if (t.lanes() <= 8) {
stream << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2] << " = " << value << ";\n";
} else {
stream << "(((nv_bfloat162*)(&(" << vec << "." << access[i / 4]
<< "))) + " << (i / 2 % 2) << ")->" << access[i % 2] << " = "
<< value << ";\n";
}
} else if (t.is_float8()) {
stream << vec;
// fp8_e5_32_t
if (t.lanes() >= 32)
stream << "." << access[i / 16];
// fp8_e5_16_t
if (t.lanes() >= 16)
stream << "." << access[(i % 16) / 8];
// fp8_e5_8_t
if (t.lanes() >= 8)
stream << "." << access[(i % 8) / 4];
// fp8_e5_4_t or fp8_e5_2_t
stream << "." << access[i % 4] << " = " << value << ";\n";
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
if (t.bits() == 16) {
@@ -799,6 +876,9 @@ std::string CodeGenTileLangCUDA::CastFromTo(std::string value, DataType from,
}
os << "int)";
}
if ((from.is_float16() || from.is_bfloat16()) && target.is_float8()) {
os << "(float)";
}
os << value << ")";
return os.str();
}
@@ -824,21 +904,25 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
bool used_bf16_op = false;
if (from_ty.is_bfloat16() || target_ty.is_bfloat16()) {
std::ostringstream func_name;
if (from_ty.is_bfloat16()) {
func_name << "bf16";
} else if (from_ty.is_float()) {
func_name << "float";
}
if (from_ty.lanes() > 1) {
func_name << from_ty.lanes();
}
func_name << "2";
if (target_ty.is_bfloat16()) {
func_name << "bf16";
} else if (target_ty.is_float()) {
func_name << "float";
} else if (target_ty == DataType::Int(16)) {
func_name << "int16";
}
if (target_ty.lanes() > 1) {
func_name << target_ty.lanes();
}
auto fname = func_name.str();
if (bf16_supported_ops_.count(fname)) {
@@ -846,20 +930,24 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
stream << "#ifdef ENABLE_BF16\n";
PrintIndent();
stream << "reinterpret_cast<";
if (target_ty.is_bfloat16()) {
stream << "__nv_bfloat16";
} else {
PrintType(target_ty.element_of(), stream);
}
if (target_ty.lanes() > 1) {
stream << target_ty.lanes();
}
stream << " &>(" << sret << ") = fastertransformer::" << fname
<< "(reinterpret_cast<";
if (from_ty.is_bfloat16()) {
stream << "__nv_bfloat16";
} else {
PrintType(from_ty.element_of(), stream);
}
if (from_ty.lanes() > 1) {
stream << from_ty.lanes();
}
stream << " const &>(" << src << "));\n";
stream << "#else\n";
}
@@ -1006,6 +1094,53 @@ std::string CodeGenTileLangCUDA::GetBufferRef(DataType t,
return os.str();
}
std::string CodeGenTileLangCUDA::GetVecLoad(DataType t,
const BufferNode *buffer,
PrimExpr base) {
const VarNode *buffer_var = buffer->data.get();
std::string scope;
if (alloc_storage_scope_.count(buffer_var)) {
scope = alloc_storage_scope_.at(buffer_var);
}
if (scope.empty()) {
scope = GetPtrStorageScope(buffer->data);
}
if (scope != "global" || t.bits() * t.lanes() <= 128) {
return this->CodeGenC::GetVecLoad(t, buffer, base);
}
ICHECK_EQ(t.bits() * t.lanes(), 256)
<< "Unsupported vector load size: " << t.bits() * t.lanes();
auto buffer_ref = this->GetBufferRef(t, buffer, base);
std::ostringstream os;
os << "tl::ld_global_256(&(" << buffer_ref << "))";
return os.str();
}
void CodeGenTileLangCUDA::PrintVecStore(const BufferNode *buffer, DataType t,
PrimExpr base,
const std::string &value) {
const VarNode *buffer_var = buffer->data.get();
std::string scope;
if (alloc_storage_scope_.count(buffer_var)) {
scope = alloc_storage_scope_.at(buffer_var);
}
if (scope.empty()) {
scope = GetPtrStorageScope(buffer->data);
}
if (scope != "global" || t.bits() * t.lanes() <= 128) {
this->CodeGenC::PrintVecStore(buffer, t, base, value);
return;
}
ICHECK_EQ(t.bits() * t.lanes(), 256)
<< "Unsupported vector load size: " << t.bits() * t.lanes();
auto buffer_ref = this->GetBufferRef(t, buffer, base);
this->PrintIndent();
this->stream << "tl::st_global_256(&(" << buffer_ref << "), " << value
<< ");\n";
}
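// Illustrative emitted code (an assumption about GetBufferRef's output, not
// taken from the patch): for a 16 x float16 access to a global buffer A whose
// buffer_ref is "((ulonglong4*)A)[i]", the two paths above generate
//   tl::ld_global_256(&(((ulonglong4*)A)[i]))
//   tl::st_global_256(&(((ulonglong4*)A)[i]), value);
// while accesses of 128 bits or less, and non-global scopes, fall back to the
// base CodeGenC implementations.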
/**
* @brief Emit CUDA/TensorLib-specific code for a call expression.
*
@@ -1151,6 +1286,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
}
} else if (op->op.same_as(builtin::ptx_cp_async_barrier())) {
print_extern_call_stmt("tl::mbarrier_cp_async_arrive");
} else if (op->op.same_as(tl::ptx_fence_barrier_init())) {
print_extern_call_stmt("tl::fence_barrier_init");
} else if (op->op.same_as(tl::ptx_cp_async_barrier_noinc())) {
print_extern_call_stmt("tl::mbarrier_cp_async_arrive_noinc");
} else if (op->op.same_as(tl::mbarrier_expect_tx())) {
@@ -2004,19 +2141,34 @@ void CodeGenTileLangCUDA::VisitExpr_(const BufferLoadNode *op,
void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8) {
if (lanes == 4) {
// make_int8x4
const int64_t *p = as_const_int(op->value);
ICHECK(p);
int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v;
if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
os << "(int)" << v;
}
return;
} else if (lanes == 32) {
// make_int8x32
const int64_t *p = as_const_int(op->value);
ICHECK(p);
int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v;
if (op->dtype.is_uint()) {
os << "make_ulonglong4(" << v << ", " << v << ", " << v << ", " << v
<< ")";
} else {
os << "make_longlong4(" << v << ", " << v << ", " << v << ", " << v
<< ")";
}
return;
}
}
if (op->dtype.is_float16()) {
@@ -2024,10 +2176,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
os << "make_";
PrintType(op->dtype, os);
os << '(';
if (lanes <= 8) {
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0)
os << ", ";
os << "__pack_half2(" << v << ", " << v << ")";
}
} else {
for (int i = 0; i < lanes / 4; ++i) {
if (i != 0)
os << ", ";
os << "tl::pack_float16x4(" << v << ", " << v << ", " << v << ", " << v
<< ")";
}
}
os << ')';
return;
@@ -2038,10 +2199,19 @@ void CodeGenTileLangCUDA::VisitExpr_(const BroadcastNode *op,
os << "make_";
PrintType(op->dtype, os);
os << '(';
if (lanes <= 8) {
for (int i = 0; i < lanes / 2; ++i) {
if (i != 0)
os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
}
} else {
for (int i = 0; i < lanes / 4; ++i) {
if (i != 0)
os << ", ";
os << "tl::pack_bfloat16x4(" << v << ", " << v << ", " << v << ", " << v
<< ")";
}
}
os << ')';
return;
...
@@ -36,6 +36,10 @@ public:
std::ostream &os) final; // NOLINT(*)
void PrintVecElemStore(const std::string &vec, DataType t, int i,
const std::string &value) final;
std::string GetVecLoad(DataType t, const BufferNode *buffer,
PrimExpr base) final;
void PrintVecStore(const BufferNode *buffer, DataType t, PrimExpr base,
const std::string &value) final;
void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string &value,
std::ostream &os) final;
...
@@ -53,6 +53,13 @@ bool TargetIsHopper(Target target) {
return arch >= 90 && arch < 100;
}
bool TargetIsSm100(Target target) {
if (!TargetIsCuda(target))
return false;
int arch = GetArchInt(target);
return arch >= 100 && arch <= 103;
}
bool TargetIsSM120(Target target) {
if (!TargetIsCuda(target))
return false;
@@ -104,6 +111,12 @@ bool TargetHasStmatrix(Target target) {
return arch >= 90;
}
bool TargetHasTmem(Target target) {
if (!TargetIsCuda(target))
return false;
return TargetIsSm100(target);
}
bool TargetHasBulkCopy(Target target) {
if (!TargetIsCuda(target))
return false;
...
@@ -19,12 +19,14 @@ bool TargetIsVolta(Target target);
bool TargetIsTuring(Target target);
bool TargetIsAmpere(Target target);
bool TargetIsHopper(Target target);
bool TargetIsSm100(Target target);
bool TargetIsSM120(Target target);
bool TargetIsCDNA(Target target);
bool TargetHasAsyncCopy(Target target);
bool TargetHasLdmatrix(Target target);
bool TargetHasStmatrix(Target target);
bool TargetHasTmem(Target target);
bool TargetHasBulkCopy(Target target);
int TargetGetWarpSize(Target target);
...
@@ -2,9 +2,14 @@
#include "common.h"
#ifdef __CUDA_ARCH_LIST__
#if __CUDA_ARCH_LIST__ >= 900
#include "copy_sm90.h"
#endif
#if __CUDA_ARCH_LIST__ >= 1000
#include "copy_sm100.h"
#endif
#endif
namespace tl {
...
#pragma once
#include "cuda_fp8.h"
#include "tcgen_05.h"
#include "tcgen_05_ld.h"
namespace tl {
__device__ __forceinline__ longlong4 ld_global_256(const longlong4 *ptr) {
longlong4 ret;
asm volatile("ld.global.v4.s64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_global_256(longlong4 *ptr, longlong4 &val) {
asm volatile("st.global.v4.s64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
__device__ __forceinline__ ulonglong4 ld_global_256(const ulonglong4 *ptr) {
ulonglong4 ret;
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
return ret;
}
// The value must be passed as `const &val`; otherwise the compiler generates a
// temporary and compilation fails for st_global_256(ptr, ld_global_256(ptr)).
__device__ __forceinline__ void st_global_256(ulonglong4 *ptr,
const ulonglong4 &val) {
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
__device__ __forceinline__ ulonglong4 ld_global_256(const fp8_e4_32_t *ptr) {
ulonglong4 ret;
asm volatile("ld.global.v4.u64 {%0, %1, %2, %3}, [%4];"
: "=l"(ret.x), "=l"(ret.y), "=l"(ret.z), "=l"(ret.w)
: "l"(ptr));
return ret;
}
__device__ __forceinline__ void st_global_256(fp8_e4_32_t *ptr,
fp8_e4_32_t &val8) {
ulonglong4 &val = *((ulonglong4 *)&val8);
asm volatile("st.global.v4.u64 [%0], {%1, %2, %3, %4};"
:
: "l"(ptr), "l"(val.x), "l"(val.y), "l"(val.z), "l"(val.w));
}
__device__ __forceinline__ unsigned long long
pack_bfloat16x4(const bfloat16_t x, const bfloat16_t y, const bfloat16_t z,
const bfloat16_t w) {
unsigned long long v0 = *((unsigned short *)&x);
unsigned long long v1 = *((unsigned short *)&y);
unsigned long long v2 = *((unsigned short *)&z);
unsigned long long v3 = *((unsigned short *)&w);
return (v0 | (v1 << 16) | (v2 << 32) | (v3 << 48));
}
__device__ __forceinline__ unsigned long long
pack_float16x4(const half x, const half y, const half z, const half w) {
unsigned long long v0 = *((unsigned short *)&x);
unsigned long long v1 = *((unsigned short *)&y);
unsigned long long v2 = *((unsigned short *)&z);
unsigned long long v3 = *((unsigned short *)&w);
return (v0 | (v1 << 16) | (v2 << 32) | (v3 << 48));
}
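// Both packing helpers place lane 0 in the low bits of the 64-bit result:
// bits [15:0] = x, [31:16] = y, [47:32] = z, [63:48] = w.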
// Helper function to find the largest K such that 2**K <= N
// Requires N > 0
template <int N, int K = 0>
__device__ __forceinline__ constexpr int get_floor_log2() {
static_assert(N > 0);
if constexpr ((1 << (K + 1)) > N)
return K;
else
return get_floor_log2<N, K + 1>();
}
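// For example: get_floor_log2<1>() == 0, get_floor_log2<5>() == 2,
// get_floor_log2<8>() == 3, get_floor_log2<100>() == 6.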
template <typename target_call_cls, int MAX_LOGN, int N, typename dst_t>
__device__ __forceinline__ void tcgen05_ld_core(uint32_t const &tmem_start_col,
dst_t *dst_ptr) {
static_assert(N > 0);
constexpr int LOG_N = get_floor_log2<N>();
constexpr int CUR_SEGMENT_LEN = 1 << (LOG_N > MAX_LOGN ? MAX_LOGN : LOG_N);
target_call_cls::copy<CUR_SEGMENT_LEN>(tmem_start_col, (uint32_t *)dst_ptr);
if constexpr (N - CUR_SEGMENT_LEN > 0) {
tcgen05_ld_core<target_call_cls, MAX_LOGN, N - CUR_SEGMENT_LEN>(
tmem_start_col + CUR_SEGMENT_LEN, dst_ptr + CUR_SEGMENT_LEN);
}
}
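// Example (illustrative): for N == 11 remaining columns and MAX_LOGN == 7,
// the recursion issues copies of length 8, 2 and 1, advancing the tmem column
// and the destination pointer by the segment length after each step.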
template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp32bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp32bNx, 7, N>(tmem_start_col + tmem_col_offset,
dst_ptr);
tl::fence_view_async_tmem_load();
}
template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp64bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp64bNx, 7, N>(tmem_start_col + tmem_col_offset,
dst_ptr);
tl::fence_view_async_tmem_load();
}
template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp128bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp128bNx, 6, N>(
tmem_start_col + tmem_col_offset, dst_ptr);
tl::fence_view_async_tmem_load();
}
template <int N, typename dst_t>
__device__ __forceinline__ void
tcgen05_ld_32dp256bNx(uint32_t const &tmem_start_col,
uint32_t const &tmem_col_offset, dst_t *dst_ptr) {
tcgen05_ld_core<tl::tmem_ld_32dp256bNx, 5, N>(
tmem_start_col + tmem_col_offset, dst_ptr);
tl::fence_view_async_tmem_load();
}
} // namespace tl
#pragma once
#include <cuda_fp8.h>
#include <cute/numeric/numeric_types.hpp>
using fp8_e4_t = cute::float_e4m3_t;
@@ -27,6 +28,19 @@ struct __CUDA_ALIGN__(16) fp8_e4_16_t {
fp8_e4_8_t y;
};
struct __CUDA_ALIGN__(32) fp8_e4_32_t {
fp8_e4_16_t x;
fp8_e4_16_t y;
__device__ __forceinline__ fp8_e4_32_t &operator=(const ulonglong4 &rhs) {
x.x = *(fp8_e4_8_t *)&rhs.x;
x.y = *(fp8_e4_8_t *)&rhs.y;
y.x = *(fp8_e4_8_t *)&rhs.z;
y.y = *(fp8_e4_8_t *)&rhs.w;
return *this;
}
};
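// fp8_e4_32_t packs 32 e4m3 values into 32 bytes; the operator= above lets a
// ulonglong4 produced by ld_global_256 (see copy_sm100.h) be reinterpreted as
// 32 packed fp8 values without an element-wise conversion.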
struct __CUDA_ALIGN__(2) fp8_e5_2_t {
fp8_e5_t x;
fp8_e5_t y;
@@ -48,3 +62,16 @@ struct __CUDA_ALIGN__(16) fp8_e5_16_t {
fp8_e5_8_t x;
fp8_e5_8_t y;
};
struct __CUDA_ALIGN__(32) fp8_e5_32_t {
fp8_e5_16_t x;
fp8_e5_16_t y;
__device__ __forceinline__ fp8_e5_32_t &operator=(const ulonglong4 &rhs) {
x.x = *(fp8_e5_8_t *)&rhs.x;
x.y = *(fp8_e5_8_t *)&rhs.y;
y.x = *(fp8_e5_8_t *)&rhs.z;
y.y = *(fp8_e5_8_t *)&rhs.w;
return *this;
}
};
@@ -48,6 +48,16 @@ template <> __device__ void debug_print_var<int>(const char *msg, int var) {
threadIdx.z, var);
}
// Specialization for unsigned integer type
template <>
__device__ void debug_print_var<unsigned int>(const char *msg,
unsigned int var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=uint "
"value=%u\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, var);
}
// Specialization for float type
template <> __device__ void debug_print_var<float>(const char *msg, float var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): dtype=float "
@@ -149,6 +159,17 @@ __device__ void debug_print_buffer_value<int>(const char *msg,
threadIdx.z, buf_name, index, var);
}
// Specialization for unsigned integer type
template <>
__device__ void
debug_print_buffer_value<unsigned int>(const char *msg, const char *buf_name,
int index, unsigned int var) {
printf("msg='%s' BlockIdx=(%d, %d, %d), ThreadIdx=(%d, %d, %d): buffer=%s, "
"index=%d, dtype=uint value=%u\n",
msg, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y,
threadIdx.z, buf_name, index, var);
}
// Specialization for float type
template <>
__device__ void debug_print_buffer_value<float>(const char *msg,
...
#pragma once
#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1200))
#include "gemm_sm120.h"
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 1000))
#include "gemm_sm100.h"
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
#include "gemm_sm90.h"
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 890))
@@ -10,5 +13,5 @@
#elif (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 700))
#include "gemm_sm70.h"
#else
// No matching architecture found
#endif
// Licensed under the MIT License.
#pragma once
#include "common.h"
#include "gemm_mma.h"
#include "intrin.h"
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/collective/collective_builder.hpp>
namespace cute {
// Extensions to CuTe
// CuTe doesn't support TCGEN5MMA with .ws, so we add it here.
// For why we need .ws, please refer to the comments in tl_tcgen5mma::GemmTensorOp
template <class a_type, class b_type, class c_type, int M, int N,
UMMA::Major a_major, UMMA::Major b_major,
UMMA::ScaleIn a_neg = UMMA::ScaleIn::One,
UMMA::ScaleIn b_neg = UMMA::ScaleIn::One>
struct SM100_MMA_F16BF16_WS_SS {
static_assert(M == 32 || M == 64 || M == 128,
"SM100_MMA_F16BF16 (with .ws) M-mode size should be 32, 64 or "
"128 for 1 CTA cluster MMA.");
static_assert(
N == 64 || N == 128 || N == 256,
"SM100_MMA_F16BF16 (with .ws) N-mode size should be 64, 128 or 256");
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
using CRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void
fma(uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scaleC, uint64_t const &idescE) {
if (cute::elect_one_sync()) {
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE >> 32)),
"r"(scaleC));
}
}
};
template <class a_type, class b_type, class c_type, int M, int N,
UMMA::Major a_major, UMMA::Major b_major, UMMA::ScaleIn a_neg,
UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F16BF16_WS_SS<a_type, b_type, c_type, M, N, a_major,
b_major, a_neg, b_neg>> {
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> == cute::sizeof_bits_v<b_type> &&
cute::sizeof_bits_v<b_type> == 16,
"SM100_MMA_F16BF16_WS_SS supports 16bit types");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;
// Logical shape-K is always 256bits, transform to units of elements
static constexpr int K = 256 / cute::sizeof_bits<ValTypeA>::value;
using Shape_MNK = Shape<Int<M>, Int<N>, Int<K>>;
using ThrID = Layout<_1>;
using ALayout =
Layout<Shape<_1, Shape<Int<M>, Int<K>>>, Stride<_0, Stride<_1, Int<M>>>>;
using BLayout =
Layout<Shape<_1, Shape<Int<N>, Int<K>>>, Stride<_0, Stride<_1, Int<N>>>>;
using CLayout =
Layout<Shape<_1, Shape<Int<M>, Int<N>>>, Stride<_0, Stride<_1, Int<M>>>>;
UMMA::InstrDescriptor idesc_ =
UMMA::make_instr_desc<a_type, b_type, c_type, M, N, a_major, b_major,
a_neg, b_neg>();
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
template <class TD, class DLayout, class TA, class ALayout, class TB,
class BLayout, class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend void
mma_unpack(MMA_Traits const &traits, Tensor<TD, DLayout> &D,
Tensor<TA, ALayout> const &A, Tensor<TB, BLayout> const &B,
Tensor<TC, CLayout> const &C) {
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TA>::value,
"Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value,
"Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t desc_a = A[0];
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F16BF16_WS_SS<a_type, b_type, c_type, M, N, a_major, b_major,
a_neg, b_neg>::fma(desc_a, desc_b, tmem_c,
uint32_t(traits.accumulate_),
idesc);
}
};
struct SM100_MMA_F8F6F4_WS_SS {
using DRegisters = void;
using ARegisters = uint64_t[1];
using BRegisters = uint64_t[1];
using CRegisters = uint32_t[1];
CUTE_HOST_DEVICE static void
fma(uint64_t const &desc_a, uint64_t const &desc_b, uint32_t const &tmem_c,
uint32_t const &scaleC, uint64_t const &idescE) {
if (cute::elect_one_sync()) {
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.ws.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, "
"p, 0; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b),
"r"(uint32_t(idescE >> 32)), "r"(scaleC));
}
}
};
template <class a_type, class b_type, class c_type, int M, int N,
UMMA::Major a_major, UMMA::Major b_major, UMMA::ScaleIn a_neg,
UMMA::ScaleIn b_neg>
struct MMA_Traits<SM100_MMA_F8F6F4_WS_SS, a_type, b_type, c_type, cute::C<M>,
cute::C<N>, cute::integral_constant<UMMA::Major, a_major>,
cute::integral_constant<UMMA::Major, b_major>,
cute::integral_constant<UMMA::ScaleIn, a_neg>,
cute::integral_constant<UMMA::ScaleIn, b_neg>> {
using ValTypeD = c_type;
using ValTypeA = a_type;
using ValTypeB = b_type;
using ValTypeC = c_type;
static_assert(cute::sizeof_bits_v<a_type> <= 8 &&
cute::sizeof_bits_v<b_type> <= 8,
"SM100_MMA_F8F6F4_WS_SS supports input types of at most 8 bits");
static_assert(M == 32 || M == 64 || M == 128,
"SM100_MMA_F8F6F4_WS_SS M-mode size should be 32, 64 or 128 "
"for 1 CTA cluster MMA.");
static_assert(
N == 64 || N == 128 || N == 256,
"SM100_MMA_F8F6F4_WS_SS (with .ws) N-mode size should be 64, 128 or 256");
using FrgTypeA = UMMA::smem_desc<a_major>;
using FrgTypeB = UMMA::smem_desc<b_major>;
using FrgTypeC = UMMA::tmem_frg_ws_1sm<c_type>;
static_assert(sizeof_bits_v<ValTypeA> <= sizeof_bits_v<uint8_t> &&
sizeof_bits_v<ValTypeB> <= sizeof_bits_v<uint8_t>);
// Logical shape-K is always 256bits, transform to units of elements
constexpr static int K = 32;
using Shape_MNK = Shape<Int<M>, Int<N>, Int<K>>;
using ThrID = Layout<_1>;
using ALayout =
Layout<Shape<_1, Shape<Int<M>, Int<K>>>, Stride<_0, Stride<_1, Int<M>>>>;
using BLayout =
Layout<Shape<_1, Shape<Int<N>, Int<K>>>, Stride<_0, Stride<_1, Int<N>>>>;
using CLayout =
Layout<Shape<_1, Shape<Int<M>, Int<N>>>, Stride<_0, Stride<_1, Int<M>>>>;
UMMA::InstrDescriptor idesc_ =
UMMA::make_instr_desc<a_type, b_type, c_type, M, N, a_major, b_major,
a_neg, b_neg>();
// Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators]
UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One;
template <class TD, class DLayout, class TA, class ALayout, class TB,
class BLayout, class TC, class CLayout>
CUTE_HOST_DEVICE constexpr friend void
mma_unpack(MMA_Traits const &traits, Tensor<TD, DLayout> &D,
Tensor<TA, ALayout> const &A, Tensor<TB, BLayout> const &B,
Tensor<TC, CLayout> const &C) {
static_assert(is_tmem<TD>::value, "Expected tmem in MMA_Atom::call");
static_assert(is_rmem<TA>::value,
"Expected desc registers in MMA_Atom::call");
static_assert(is_rmem<TB>::value,
"Expected desc registers in MMA_Atom::call");
static_assert(is_tmem<TC>::value, "Expected tmem in MMA_Atom::call");
uint64_t desc_a = A[0];
uint64_t desc_b = B[0];
uint32_t tmem_c = raw_pointer_cast(D.data());
uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_);
SM100_MMA_F8F6F4_WS_SS::fma(desc_a, desc_b, tmem_c,
uint32_t(traits.accumulate_), idesc);
}
};
namespace tl_tcgen5mma {
using cutlass::gemm::collective::detail::sm100_smem_selector;
template <typename A_type, typename B_type, typename C_type, int M, int N,
int K, UMMA::Major a_major, UMMA::Major b_major,
typename Enable = void>
struct DispatchInstruction;
template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, M, N, K, a_major,
b_major, std::enable_if_t<M == 128 && K == 16>> {
using MMA = SM100_MMA_F16BF16_SS<bfloat16_t, bfloat16_t, float, M, N, a_major,
b_major>;
};
template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<bfloat16_t, bfloat16_t, float, M, N, K, a_major,
b_major,
std::enable_if_t<(M == 64 || M == 32) && K == 16>> {
using MMA = SM100_MMA_F16BF16_WS_SS<bfloat16_t, bfloat16_t, float, M, N,
a_major, b_major>;
};
template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<half_t, half_t, float, M, N, K, a_major, b_major,
std::enable_if_t<M == 128 && K == 16>> {
using MMA =
SM100_MMA_F16BF16_SS<half_t, half_t, float, M, N, a_major, b_major>;
};
template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<half_t, half_t, float, M, N, K, a_major, b_major,
std::enable_if_t<(M == 64 || M == 32) && K == 16>> {
using MMA =
SM100_MMA_F16BF16_WS_SS<half_t, half_t, float, M, N, a_major, b_major>;
};
template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, M, N, K, a_major, b_major,
std::enable_if_t<M == 128 && K == 32>> {
using MMA = MMA_Traits<SM100_MMA_F8F6F4_SS, fp8_e4_t, fp8_e4_t, float, Int<M>,
Int<N>, integral_constant<UMMA::Major, a_major>,
integral_constant<UMMA::Major, b_major>,
integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
};
template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, M, N, K, a_major, b_major,
std::enable_if_t<(M == 64 || M == 32) && K == 32>> {
using MMA =
MMA_Traits<SM100_MMA_F8F6F4_WS_SS, fp8_e4_t, fp8_e4_t, float, Int<M>,
Int<N>, integral_constant<UMMA::Major, a_major>,
integral_constant<UMMA::Major, b_major>,
integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
};
template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, M, N, K, a_major, b_major,
std::enable_if_t<M == 128 && K == 32>> {
using MMA = MMA_Traits<SM100_MMA_F8F6F4_SS, fp8_e5_t, fp8_e5_t, float, Int<M>,
Int<N>, integral_constant<UMMA::Major, a_major>,
integral_constant<UMMA::Major, b_major>,
integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
};
template <int M, int N, int K, UMMA::Major a_major, UMMA::Major b_major>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, M, N, K, a_major, b_major,
std::enable_if_t<(M == 64 || M == 32) && K == 32>> {
using MMA =
MMA_Traits<SM100_MMA_F8F6F4_WS_SS, fp8_e5_t, fp8_e5_t, float, Int<M>,
Int<N>, integral_constant<UMMA::Major, a_major>,
integral_constant<UMMA::Major, b_major>,
integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>,
integral_constant<UMMA::ScaleIn, UMMA::ScaleIn::One>>;
};
template <int M, int N, int K, int AtomM, int AtomN, int AtomK, bool trans_A,
bool trans_B, typename A_type_raw, typename B_type_raw,
typename C_type_raw>
class GemmTensorOp {
public:
using A_type =
typename std::conditional<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
using B_type =
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>::type;
using C_type = C_type_raw;
static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32);
static constexpr UMMA::Major UmmaMajorA =
trans_A ? UMMA::Major::MN : UMMA::Major::K;
static constexpr UMMA::Major UmmaMajorB =
trans_B ? UMMA::Major::K : UMMA::Major::MN;
using SmemLayoutAtomA =
decltype(sm100_smem_selector<UmmaMajorA, A_type, Int<M>, Int<K>>());
using SmemLayoutAtomB =
decltype(sm100_smem_selector<UmmaMajorB, B_type, Int<N>, Int<K>>());
using SmemLayoutA = decltype(tile_to_shape(
SmemLayoutAtomA{}, Shape<Int<M>, Int<K>>{},
conditional_t<trans_A, Step<_2, _1>, Step<_1, _2>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{}, Shape<Int<N>, Int<K>>{},
conditional_t<trans_B, Step<_1, _2>, Step<_2, _1>>{}));
static CUTE_DEVICE void body_ss(A_type_raw *pA, B_type_raw *pB, uint32_t pC,
uint64_t *umma_bar_ptr, bool clear_accum) {
Tensor sA = make_tensor(make_smem_ptr(reinterpret_cast<A_type *>(pA)),
SmemLayoutA{});
Tensor sB = make_tensor(make_smem_ptr(reinterpret_cast<B_type *>(pB)),
SmemLayoutB{});
// TODO (lei): Normal TCGEN5MMA (the one w/o .ws) doesn't saturate all 128
// lanes when M == 64
// (see layout F in
// https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-data-path-layout-f)
// So we use the .ws variant here
using MmaAtom =
typename DispatchInstruction<A_type, B_type, C_type, AtomM, AtomN,
AtomK, UmmaMajorA, UmmaMajorB>::MMA;
auto tiled_mma = make_tiled_mma(MmaAtom{}, Layout<Shape<_1>>{},
Tile<Int<M>, Int<N>, Int<K>>{});
auto thr_mma = tiled_mma.get_slice(_0{});
tiled_mma.accumulate_ =
clear_accum ? UMMA::ScaleOut::Zero : UMMA::ScaleOut::One;
Tensor acc = partition_fragment_C(tiled_mma, Shape<Int<M>, Int<N>>{});
acc.data() = pC;
Tensor sA_frag = thr_mma.partition_fragment_A(sA);
Tensor sB_frag = thr_mma.partition_fragment_B(sB);
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(sA_frag); ++k_block) {
cute::gemm(tiled_mma, sA_frag(_, _, k_block), sB_frag(_, _, k_block),
acc);
tiled_mma.accumulate_ = UMMA::ScaleOut::One;
}
cutlass::arch::umma_arrive(umma_bar_ptr);
}
};
} // namespace tl_tcgen5mma
} // namespace cute
namespace tl {
using tl_mma::gemm_rs;
using tl_mma::gemm_sr;
using tl_mma::gemm_ss;
// TODO (lei): Implement gemm_ts
// template <int M, int N, int K, int warp_m, int warp_n, bool trans_A, bool
// trans_B, bool clear_accum, typename A_type, typename B_type, typename C_type>
// TL_DEVICE void gemm_ts(A_type *pA, B_type *pB, C_type *accum, uint64_t
// *umma_bar_ptr) {
// }
template <int M, int N, int K, int AtomM, int AtomN, int AtomK, bool trans_A,
bool trans_B, typename C_type, typename A_type, typename B_type>
TL_DEVICE void tcgen5mma_gemm_ss(A_type *pA, B_type *pB, uint32_t accum,
uint64_t *umma_bar_ptr, bool clear_accum) {
using MMA =
cute::tl_tcgen5mma::GemmTensorOp<M, N, K, AtomM, AtomN, AtomK, trans_A,
trans_B, A_type, B_type, C_type>;
MMA::body_ss(pA, pB, accum, umma_bar_ptr, clear_accum);
}
} // namespace tl
#pragma once
#include <cstdint>
#ifndef __CUDACC_RTC__
#include <cuda.h>
#endif
#include "common.h"
namespace tl {
TL_DEVICE void tmem_allocate(void *dst_ptr, int num_columns) {
uint32_t dst_intptr = smem_ptr_to_uint(dst_ptr);
asm volatile(
"tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
:
: "r"(dst_intptr), "r"(num_columns));
}
TL_DEVICE void tmem_deallocate(uint32_t *tmem_ptr, int num_columns) {
asm volatile("{\n\t"
"tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1; \n\t"
"}"
:
: "r"(*tmem_ptr), "r"(num_columns));
}
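// Illustrative usage (not part of this patch; names are placeholders): pair
// every allocation with a deallocation of the same column count.
//   __shared__ uint32_t tmem_addr;                  // receives the TMEM base address
//   tmem_allocate(&tmem_addr, /*num_columns=*/128);
//   ...                                             // TCGEN5MMA / tcgen05 loads use tmem_addr
//   tmem_deallocate(&tmem_addr, /*num_columns=*/128);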
inline void __device__ fence_view_async_tmem_load() {
asm volatile("tcgen05.wait::ld.sync.aligned; " ::);
}
inline void __device__ fence_view_async_tmem_store() {
asm volatile("tcgen05.wait::st.sync.aligned; " ::);
}
template <int M, int N>
inline void __device__ amma_fp16bf16_ss(uint64_t const desc_a,
uint64_t const desc_b,
uint32_t const tmem_c,
uint32_t const idesc,
uint32_t const addC = 1) {
static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16 M-mode size should be "
"64 or 128 for 1 CTA cluster MMA.");
static_assert(
(M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) ||
(M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)),
"SM100_MMA_F16BF16 N-mode size should be a multiple of 8 between 8 and 256 for M=64,\
or a multiple of 16 between 16 and 256 for M=128.");
uint32_t mask[4] = {0, 0, 0, 0};
asm volatile("{\n\t"
".reg .pred p;\n\t"
"setp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, "
"%7, %8}, p; \n\t"
"}\n"
:
: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(idesc), "r"(addC),
"r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]));
}
inline __device__ void amma_commit(uint64_t const *smem_ptr) {
uint32_t bar_intptr = smem_ptr_to_uint(smem_ptr);
asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::"
"cluster.b64 [%0];"
:
: "r"(bar_intptr));
}
} // namespace tl
\ No newline at end of file
@@ -23,7 +23,8 @@
*/
#include "loop_vectorize.h"
#include "../op/builtin.h"
#include "../target/utils.h"
#include "arith/int_operator.h" #include "arith/int_operator.h"
#include "arith/ir_visitor_with_analyzer.h" #include "arith/ir_visitor_with_analyzer.h"
#include "common/loop_vectorization_utils.h" #include "common/loop_vectorization_utils.h"
...@@ -44,11 +45,48 @@ struct VectorizePlanResult { ...@@ -44,11 +45,48 @@ struct VectorizePlanResult {
PrimExpr condition; PrimExpr condition;
}; };
class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer {
public:
VectorizeFindGlobalAccess() = default;
bool HasGlobalAccess(const Stmt &stmt) {
this->operator()(stmt);
return has_global_access_;
}
private:
bool has_global_access_ = false;
void VisitStmt_(const BufferStoreNode *node) final {
if (node->buffer.scope() == "global")
has_global_access_ = true;
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
}
void VisitExpr_(const BufferLoadNode *node) final {
if (node->buffer.scope() == "global")
has_global_access_ = true;
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
};
class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
public:
VectorizePlanner() = default;
int Plan(const For &node) {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
Optional<Bool> opt_disable_vectorize_256 =
ctxt->GetConfig(kDisableVectorize256, Optional<Bool>());
bool disable_vectorize_256 =
opt_disable_vectorize_256.value_or(Bool(false));
if (tvm::tl::TargetIsSm100(Target::Current(false)) &&
!disable_vectorize_256 &&
VectorizeFindGlobalAccess().HasGlobalAccess(node)) {
vector_load_bits_max_ = vector_size_ = 256;
} else {
vector_load_bits_max_ = vector_size_ = 128;
}
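// Illustrative effect: on an SM100 target, a loop that touches a global
// buffer starts from a 256-bit vectorization budget (e.g. 16 x float16 per
// access), while other targets or loops keep the usual 128-bit cap.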
this->operator()(node);
return vector_size_;
}
@@ -110,7 +148,13 @@ private:
// TODO: perform some checks here
}
void VisitExpr_(const CastNode *node) final {
vector_size_ = arith::ZeroAwareGCD(
vector_load_bits_max_ / node->dtype.bits(), vector_size_);
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
}
void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
if (!inner_for_)
return;
// 1. Compute raw element offset
@@ -144,7 +188,7 @@ private:
}
}
int vector_load_bits_max_;
const ForNode *inner_for_{};
bool has_nonlocal_memory_access_ = false;
...
/*!
* \file lower_shared_tmem.cc
* \brief Convert shared.tmem buffers to plain shared + ptx init, and do
* coordinate translation (from logical address to physical address)
*/
#include "../op/builtin.h"
#include "../target/utils.h"
#include "tvm/ir/type.h"
#include "tvm/tir/builtin.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
#include <tvm/arith/analyzer.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
namespace tvm {
namespace tl {
using namespace tir;
class SharedTmemRewriter : public StmtExprMutator {
public:
static Stmt Rewrite(Stmt body) {
SharedTmemRewriter rewriter;
return rewriter(body);
}
private:
Stmt VisitStmt_(const BlockNode *op) final {
Block block = GetRef<Block>(op);
Array<Buffer> alloc_buffers = op->alloc_buffers;
if (op->annotations.count(attr::kLayoutMap)) {
auto layout_map = op->annotations.Get(attr::kLayoutMap);
ICHECK(layout_map) << "layout map is not defined";
layout_map_ = layout_map->as<Map<Buffer, Layout>>().value();
}
// Record the mapping from buffer data var to buffer for later lookup
for (auto buffer : alloc_buffers) {
buffer_map_.insert({buffer->data, buffer});
}
for (auto match_buffer : op->match_buffers) {
buffer_map_.insert({match_buffer->buffer->data, match_buffer->buffer});
}
Array<Buffer> tmem_buffers;
for (const auto &[data, buffer] : buffer_map_) {
const auto *ptr_type =
buffer->data->type_annotation.as<PointerTypeNode>();
ICHECK(ptr_type) << "Buffer Var's type annotation must be of PointerType";
auto storage_scope = ptr_type->storage_scope;
if (storage_scope == "shared.tmem") {
tmem_buffers.push_back(buffer);
}
}
if (tmem_buffers.empty()) {
return StmtExprMutator::VisitStmt_(op);
}
ICHECK(thread_var_.defined()) << "thread_var_ is not defined";
for (auto buffer : tmem_buffers) {
buffer_data_to_buffer_.Set(buffer->data, buffer);
}
/*
Transform the tmem buffers to new allocations
transform:
tmem_buf0 = T.alloc_buffer((128, 128,), "uint64",
scope="shared.tmem")
tmem_buf1 = T.alloc_buffer((128, 128,), "uint64",
scope="shared.tmem")
into:
tmem_buf0 = T.alloc_buffer((1,), "uint64", scope="shared.tmem_addr")
tmem_buf1 = T.alloc_buffer((1,), "uint64", scope="shared.tmem_addr")
if tx == 0:
T.ptx_init_tensor_memory(tmem_buf0[0], 128)
T.ptx_init_tensor_memory(tmem_buf1[0], 128)
*/
// 1. create new data vars
Array<Var> new_data_vars;
for (auto buffer : tmem_buffers) {
auto data = buffer->data;
auto new_data =
Var(data->name_hint, PointerType(PrimType(tmem_dtype_), "shared"));
var_remap_.Set(data, new_data);
new_data_vars.push_back(new_data);
}
// 2. create new buffers
Array<Buffer> new_buffers;
for (auto buffer : tmem_buffers) {
auto data = buffer->data;
ICHECK(var_remap_.find(data) != var_remap_.end())
<< "data not found in var_remap_";
auto new_data = var_remap_.at(data);
auto new_buffer = Buffer(new_data, tmem_dtype_, Array<PrimExpr>({1}),
Array<PrimExpr>({1}), PrimExpr(0), buffer->name,
buffer->data_alignment, buffer->offset_factor,
buffer->buffer_type);
new_buffers.push_back(new_buffer);
buffer_remap_.Set(buffer, new_buffer);
}
// remove the tmem buffers
alloc_buffers.MutateByApply([this](Buffer buf) {
if (buffer_remap_.find(buf) != buffer_remap_.end()) {
return buffer_remap_.at(buf);
}
return buf;
});
if (!alloc_buffers.same_as(op->alloc_buffers)) {
block.CopyOnWrite()->alloc_buffers = alloc_buffers;
} else {
return StmtExprMutator::VisitStmt_(op);
}
// 3. create init & dealloc calls for new buffers
std::vector<Stmt> init_mtmem_calls_;
std::vector<Stmt> dealloc_tmem_calls_;
for (auto buffer : tmem_buffers) {
auto data = buffer->data;
auto old_buffer = buffer_data_to_buffer_.at(data);
auto new_buffer = buffer_remap_.at(old_buffer);
// Tmem physical coord range analysis
ICHECK(old_buffer->shape.size() == 2);
auto analyzer = std::make_shared<arith::Analyzer>();
arith::ConstIntBound phy_col_bounds =
analyzer->const_int_bound(old_buffer->shape[1]);
int num_cols_required = phy_col_bounds->max_value;
ICHECK(num_cols_required <= 512)
<< "The number of columns required for tmem buffer "
<< old_buffer->name << " is " << num_cols_required
<< ", which exceeds the maximum of 512 columns";
int num_cols_allocated = 32; // Align num_cols_allocated to power of 2
for (; num_cols_allocated < num_cols_required; num_cols_allocated *= 2)
;
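// Example (illustrative): num_cols_required = 100 rounds up to
// num_cols_allocated = 128; num_cols_required = 32 stays at 32.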
auto new_buffer_access = new_buffer.access_ptr(1, DataType::Handle(), 1,
PrimExpr(0), PrimExpr(1));
auto alloc_call = Call(DataType::Handle(), tl::ptx_init_tensor_memory(),
{new_buffer_access, PrimExpr(num_cols_allocated)});
init_mtmem_calls_.push_back(Evaluate(alloc_call));
auto dealloc_call =
Call(DataType::Handle(), tl::ptx_deallocate_tensor_memory(),
{new_buffer_access, PrimExpr(num_cols_allocated)});
dealloc_tmem_calls_.push_back(Evaluate(dealloc_call));
}
auto compare_by_buffer_name = [&](const Stmt &a, const Stmt &b) {
auto call_a = a.as<EvaluateNode>()->value.as<CallNode>();
auto call_b = b.as<EvaluateNode>()->value.as<CallNode>();
auto num_cols_a = call_a->args[1].as<IntImmNode>()->value;
auto num_cols_b = call_b->args[1].as<IntImmNode>()->value;
return num_cols_a > num_cols_b;
};
std::sort(init_mtmem_calls_.begin(), init_mtmem_calls_.end(),
compare_by_buffer_name);
Array<Stmt> new_body;
auto target = Target::Current();
auto warp_size = TargetGetWarpSize(target);
auto thread_var_div_warp_size =
FloorDiv(thread_var_->var, IntImm(thread_var_->var->dtype, warp_size));
new_body.push_back(IfThenElse(EQ(thread_var_div_warp_size, 0),
init_mtmem_calls_.size() > 1
? SeqStmt(init_mtmem_calls_)
: init_mtmem_calls_.back(),
Stmt()));
new_body.push_back(
Evaluate(Call(DataType::Handle(), builtin::tvm_storage_sync(),
{StringImm("shared")})));
new_body.push_back(block->body);
new_body.push_back(IfThenElse(EQ(thread_var_div_warp_size, 0),
dealloc_tmem_calls_.size() > 1
? SeqStmt(dealloc_tmem_calls_)
: dealloc_tmem_calls_.back(),
Stmt()));
auto block_ptr = block.CopyOnWrite();
block_ptr->annotations.erase(attr::kLayoutMap);
block_ptr->body = SeqStmt(new_body);
return StmtExprMutator::VisitStmt_(block.get());
}
PrimExpr GetTmemOffset(const Buffer &buffer, const Array<PrimExpr> &indices) {
ICHECK(buffer->shape.size() == 2);
ICHECK(indices.size() == 2);
ICHECK(layout_map_.defined());
ICHECK(layout_map_.count(buffer))
<< "The layout of tmem buffer " << buffer->name
<< " is not defined in the layout map";
auto layout = layout_map_[buffer];
ICHECK(layout.defined());
Array<PrimExpr> tmem_phy_coords = layout->Forward(indices);
PrimExpr result =
tmem_phy_coords[0] << 16 |
tmem_phy_coords
[1]; // https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-memory-addressing
return result;
}
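// Example (illustrative): if layout->Forward maps the logical indices to the
// physical coordinates (row 2, column 5), the packed offset is
// (2 << 16) | 5 = 0x00020005.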
PrimExpr VisitExpr_(const BufferLoadNode *op) final {
// Translate tmem[logical_row, logical_col] to tmem[0] + tmem_offset
// Where
// - (logical_row, logical_col) is the logical address in the tmem buffer
// - tmem[0] is the base address allocated for the tmem buffer
// - tmem_offset = tmem_phy_coords[0]<<16 | tmem_phy_coords[1]
// where tmem_phy_coords = layout.Forward(logical_row, logical_col)
// is the physical address in the tmem buffer
auto load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto buffer = load->buffer;
auto indices = load->indices;
if (buffer_remap_.count(buffer)) {
auto new_buffer = buffer_remap_[load->buffer];
return BufferLoad(new_buffer, {0}) + GetTmemOffset(buffer, indices);
} else if (var_remap_.count(buffer->data)) {
auto new_buffer = Buffer(
var_remap_[buffer->data], tmem_dtype_, buffer->shape, buffer->strides,
buffer->elem_offset, buffer->name, buffer->data_alignment,
buffer->offset_factor, buffer->buffer_type);
return BufferLoad(new_buffer, {0}) + GetTmemOffset(buffer, indices);
}
return load;
}
Stmt VisitStmt_(const BufferStoreNode *op) final {
auto store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto buffer = store->buffer;
ICHECK(buffer.scope() != "shared.tmem")
<< "We should never directly store data into tmem!";
return store;
}
PrimExpr VisitExpr_(const CallNode *op) final {
if (op->op.same_as(builtin::tvm_access_ptr())) {
ICHECK_EQ(op->args.size(), 5U);
Var buffer_data = Downcast<Var>(op->args[1]);
if (!var_remap_.count(buffer_data)) {
return StmtExprMutator::VisitExpr_(op);
}
Var new_data = var_remap_[buffer_data];
return Call(
op->dtype, op->op,
{op->args[0], new_data, op->args[2], op->args[3], op->args[4]});
}
return StmtExprMutator::VisitExpr_(op);
}
Stmt VisitStmt_(const AttrStmtNode *op) final {
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag == "threadIdx.x") {
ICHECK(iv->dom->extent.as<IntImmNode>());
thread_var_ = iv;
}
}
return StmtExprMutator::VisitStmt_(op);
}
// Datatypes for tmem
const DataType tmem_dtype_ = DataType::UInt(32);
// This is a workaround for the CPU backend:
// we need to define a thread_var for the serial loop.
IterVar thread_var_;
Map<Var, Var> var_remap_;
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, Buffer> buffer_remap_;
// Mapping from data Var of a Buffer to Buffer, for lookup
std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_map_;
Map<Buffer, Layout> layout_map_;
};
PrimFunc LowerSharedTmem(PrimFunc f) {
auto target = f->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target.defined()) << "LowerSharedTmem: Require the target attribute";
SharedTmemRewriter rewriter;
f.CopyOnWrite()->body = rewriter.Rewrite(f->body);
return f;
}
namespace transform {
using namespace tir::transform;
tvm::transform::Pass LowerSharedTmem() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return tl::LowerSharedTmem(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tl.LowerSharedTmem", {});
}
TVM_FFI_STATIC_INIT_BLOCK({
namespace refl = tvm::ffi::reflection;
refl::GlobalDef().def("tl.transform.LowerSharedTmem", LowerSharedTmem);
});
} // namespace transform
} // namespace tl
} // namespace tvm
...@@ -73,6 +73,34 @@ static Buffer makeBufferWithLayout(const Buffer &buffer, const Layout &layout,
buffer->buffer_type);
}
// The function `makeBufferWithLayout` creates a new Buffer object based on the
// given buffer and layout. It handles remapping of buffer variables, adjusts
// the storage scope if needed (e.g., from "local.fragment" to "local"), and
// computes the output shape according to the layout. For shared memory buffers,
// it also handles replication if the buffer's extent is larger than the
// layout's extent.
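// Illustrative sketch (hypothetical shapes, not taken from this pass): a
// "local.fragment" accumulator buffer of shape (64, 64) paired with a
// fragment layout is rewritten into a "local" buffer whose shape follows the
// layout's output space and whose data Var comes from the remap table; a
// shared buffer whose extent exceeds the layout's extent additionally keeps a
// leading replication dimension.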
class LayoutRemapRewriter : public arith::IRMutatorWithAnalyzer {
public:
static Stmt Substitute(Stmt stmt, Map<Buffer, Layout> layout_remap) {
arith::Analyzer analyzer;
LayoutRemapRewriter substituter(&analyzer);
substituter.layout_remap_ = std::move(layout_remap);
return substituter.VisitStmt(stmt);
}
private:
using arith::IRMutatorWithAnalyzer::IRMutatorWithAnalyzer;
Stmt VisitStmt_(const BlockNode *op) final {
auto block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
if (op->annotations.count(attr::kLayoutMap)) {
block.CopyOnWrite()->annotations.Set(attr::kLayoutMap, layout_remap_);
}
return block;
}
Map<Buffer, Layout> layout_remap_;
};
class BufferGemmCollector : public StmtExprVisitor {
public:
BufferGemmCollector() { Clear(); }
...@@ -227,6 +255,8 @@ public:
fptr->body = substituter.VisitStmt(f->body);
fptr->body =
RemapBufferRewriter::Substitute(fptr->body, substituter.buffer_remap_);
fptr->body =
LayoutRemapRewriter::Substitute(fptr->body, substituter.layout_remap_);
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
Optional<Bool> opt_disable_tma_lower =
ctxt->GetConfig(kDisableTMALower, Optional<Bool>());
...@@ -275,7 +305,6 @@ private:
for (const auto &buffer : workspaces_)
block_ptr->alloc_buffers.push_back(buffer);
workspaces_.clear();
block_ptr->annotations.erase(attr::kLayoutMap);
return block;
}
...@@ -363,6 +392,7 @@ private:
auto new_access_ptr = access_ptr_call.CopyOnWrite();
new_access_ptr->args.Set(0, BufferLoad(new_buffer, new_indices));
layout_remap_.Set(new_buffer, layout_map_[load->buffer]);
} else {
LOG(FATAL) << "Invalid access op for permuted layout: " << access_ptr;
}
...@@ -430,6 +460,7 @@ private:
if (buffer_remap_.count(buffer)) {
auto new_indices = layout_map_[buffer]->Forward(load->indices);
auto new_buffer = buffer_remap_[load->buffer];
layout_remap_.Set(new_buffer, layout_map_[load->buffer]);
return BufferLoad(new_buffer, new_indices);
} else if (var_remap_.count(buffer->data)) {
auto new_buffer = Buffer(
...@@ -447,6 +478,7 @@ private:
if (buffer_remap_.count(buffer)) {
auto new_indices = layout_map_[buffer]->Forward(store->indices);
auto new_buffer = buffer_remap_[store->buffer];
layout_remap_.Set(new_buffer, layout_map_[store->buffer]);
return BufferStore(new_buffer, store->value, new_indices);
} else if (var_remap_.count(buffer->data)) {
auto new_buffer = Buffer(
...@@ -547,6 +579,7 @@ private:
Target target_;
Map<Var, Buffer> buffer_data_to_buffer_;
Map<Buffer, Layout> layout_map_;
Map<Buffer, Layout> layout_remap_;
Map<Buffer, Buffer> buffer_remap_;
// This is a workaround for the CPU backend:
// we need to define a thread_var for the serial loop.
...
...@@ -5,6 +5,7 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include "../op/builtin.h"
#include <utility>
#include "../target/utils.h"
...@@ -35,6 +36,110 @@ bool MayConflict(const Region &region1, const Region &region2) {
return true;
}
class TmemLoadCollector : public StmtExprVisitor {
public:
TmemLoadCollector() {}
Buffer result;
private:
void VisitExpr_(const BufferLoadNode *op) {
Buffer buf = op->buffer;
if (buf->data->type_annotation.as<PointerTypeNode>()->storage_scope ==
"shared") {
// We only care about shared.tmem buffers
ICHECK(!result.defined())
<< "TmemLoadCollector: More than one shared buffer visited";
result = buf;
}
}
};
/*!
* \brief Build the dependency chain between async operations and their
* corresponding buffers & synchronizations.
*
* Example:
* If we encounter the following pattern:
*
* tcgen5mma_gemm_ts(..., mbar, ...)
* mbarrier_wait_parity(mbar)
*
* The builder will link the mbarrier to the buffers used in the
* TCGEN5MMA
*/
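// For the example above, assuming A and B are the operand buffers and C is
// the tmem accumulator of the TCGEN5MMA call (names illustrative), the
// builder records roughly:
//   mbar_to_buffer_reads_[mbar]  -> { FullRegion(A), FullRegion(B) }
//                                   (plus FullRegion(C) when clear_accum is
//                                   not provably true)
//   mbar_to_buffer_writes_[mbar] -> { FullRegion(C) }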
class AsyncDependencyChainBuilder : public StmtExprVisitor {
public:
AsyncDependencyChainBuilder(Map<Var, Buffer> buffer_data_to_buffer)
: buffer_data_to_buffer_(buffer_data_to_buffer) {}
std::unordered_map<const BufferNode *, Array<BufferRegion>>
mbar_to_buffer_reads_;
std::unordered_map<const BufferNode *, Array<BufferRegion>>
mbar_to_buffer_writes_;
private:
Map<Var, Buffer> buffer_data_to_buffer_;
void VisitExpr_(const CallNode *op) final {
auto args = op->args;
if (op->op.same_as(builtin::call_extern())) {
std::string func_name_with_template = args[0].as<StringImmNode>()->value;
std::size_t le_pos = func_name_with_template.find_first_of('<');
std::string func_name = le_pos == std::string::npos
? func_name_with_template
: func_name_with_template.substr(0, le_pos);
if (func_name == "tl::utcmma_gemm_ts" ||
func_name == "tl::utcmma_gemm_ss") {
// TCGEN5MMA
auto get_buf_from_access_ptr_call =
[&](const PrimExpr &expr) -> Buffer {
auto call = expr.as<CallNode>();
ICHECK(call);
ICHECK(call->op.same_as(builtin::tvm_access_ptr()));
auto var = call->args[1].as<VarNode>();
ICHECK(var);
auto it = buffer_data_to_buffer_.find(GetRef<Var>(var));
ICHECK(it != buffer_data_to_buffer_.end());
return (*it).second;
};
Buffer a_buf = get_buf_from_access_ptr_call(args[1]);
Buffer b_buf = get_buf_from_access_ptr_call(args[2]);
Buffer mbar_buf = get_buf_from_access_ptr_call(args[4]);
TmemLoadCollector tmem_collector;
tmem_collector(args[3]);
ICHECK(tmem_collector.result.defined())
<< "TmemLoadCollector: No tmem buffer load found in the TCGEN5MMA "
"call";
Buffer c_buf = tmem_collector.result;
PrimExpr clear_accum = args[5];
mbar_to_buffer_reads_[mbar_buf.get()].push_back(
BufferRegion::FullRegion(a_buf));
mbar_to_buffer_reads_[mbar_buf.get()].push_back(
BufferRegion::FullRegion(b_buf));
mbar_to_buffer_writes_[mbar_buf.get()].push_back(
BufferRegion::FullRegion(c_buf));
auto analyzer = std::make_shared<arith::Analyzer>();
if (!analyzer->CanProveEqual(clear_accum, Bool(true))) {
mbar_to_buffer_reads_[mbar_buf.get()].push_back(
BufferRegion::FullRegion(c_buf));
}
}
// TODO (lei) Link wgmma to buffers and tl.wait_wgmma
} else if (op->op.same_as(tir::builtin::if_then_else())) {
const PrimExpr &then_expr = args[1];
const PrimExpr &else_expr = args[2];
this->VisitExpr(then_expr);
this->VisitExpr(else_expr);
} else {
StmtExprVisitor::VisitExpr_(op);
}
}
};
/*!
* \brief Detect if a statement follows the global memory copy pattern:
* 1. Contains exactly one buffer store operation
...@@ -43,8 +148,10 @@ bool MayConflict(const Region &region1, const Region &region2) {
*/
class BufferRegionCollector : public StmtExprVisitor {
public:
BufferRegionCollector(Map<Var, Buffer> buffer_data_to_buffer,
const AsyncDependencyChainBuilder &chain_builder)
: buffer_data_to_buffer_(buffer_data_to_buffer),
chain_builder_(chain_builder) {}
Array<BufferRegion> GetReads() const { return reads_; }
...@@ -117,6 +224,23 @@ private:
for (auto i = 1; i < op->args.size(); i++) {
this->VisitExpr(op->args[i]);
}
} else if (op->op.same_as(tl::mbarrier_wait_parity())) {
ICHECK(args[0].as<BufferLoadNode>());
Buffer mbar_buf = args[0].as<BufferLoadNode>()->buffer;
auto buffer_reads =
chain_builder_.mbar_to_buffer_reads_.find(mbar_buf.get());
auto buffer_writes =
chain_builder_.mbar_to_buffer_writes_.find(mbar_buf.get());
if (buffer_reads != chain_builder_.mbar_to_buffer_reads_.end()) {
reads_.insert(reads_.end(), buffer_reads->second.begin(),
buffer_reads->second.end());
}
if (buffer_writes != chain_builder_.mbar_to_buffer_writes_.end()) {
writes_.insert(writes_.end(), buffer_writes->second.begin(),
buffer_writes->second.end());
}
} else {
StmtExprVisitor::VisitExpr_(op);
}
...@@ -135,6 +259,7 @@ private:
}
private:
AsyncDependencyChainBuilder chain_builder_;
Map<Var, Buffer> buffer_data_to_buffer_;
Array<BufferRegion> reads_;
Array<BufferRegion> writes_;
...@@ -200,12 +325,15 @@ private:
}
};
PipelineStageInfo
MakePipelineStageInfo(Stmt stmt, int idx,
AsyncDependencyChainBuilder &chain_builder) {
Block block(/*iter_vars=*/{}, /*reads=*/{}, /*writes=*/{}, /*name_hint=*/"",
/*body*/ std::move(stmt));
Array<Array<BufferRegion>> access =
GetBlockReadWriteRegion(block, buffer_data_to_buffer_);
auto collector =
BufferRegionCollector(buffer_data_to_buffer_, chain_builder);
collector(block);
PipelineStageInfo pinfo;
pinfo.reads = std::move(collector.GetReads());
...@@ -299,9 +427,13 @@ private:
CHECK(num_stages >= 1);
CHECK(loop->kind == ForKind::kSerial);
AsyncDependencyChainBuilder chain_builder(buffer_data_to_buffer_);
chain_builder(pipeline_body);
std::vector<PipelineStageInfo> pipeline_stage_infos;
for (size_t i = 0; i < pipeline_body_seq->size(); i++) {
auto pinfo =
MakePipelineStageInfo(pipeline_body_seq->seq[i], i, chain_builder);
pipeline_stage_infos.push_back(std::move(pinfo));
}
...
...@@ -49,7 +49,8 @@ def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="flo
def assert_matmul_codegen(M=1024, N=1024, K=1024, block_M=128, block_N=128, block_K=32):
func = matmul(M, N, K, block_M, block_N, block_K)
with tvm.target.Target("c"):
    artifact = tilelang.lower(func)
code = artifact.kernel_source
...@@ -101,7 +102,8 @@ def test_matmul_compile():
M, N, K = 1024, 512, 512
block_M, block_N, block_K = M // 4, N // 4, K // 4
cpu_func = matmul_jit_test(M, N, K, block_M, block_N, block_K)
with tvm.target.Target("c"):
    complied_fun = tilelang.compile(cpu_func, -1, execution_backend="ctypes")
in_dtype = "float16"
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype))
...
...@@ -82,6 +82,7 @@ def run_gemm(
)
kernel = tilelang.compile(program, out_idx=[2])
print(kernel.get_kernel_source())
profiler = kernel.get_profiler()
def ref_program(A, B):
...
...@@ -77,16 +77,17 @@ def test_loop_tail_split(block_M, block_N, block_K, threads, vec_load_b, dtype):
bx * block_N + t % (block_N // vec_load_b) *
(block_N // vec_load_b) + vec], T.float16(0))
with tvm.target.Target(auto_target):
    mod = tvm.tir.transform.BindTarget(auto_target)(Before)
    mod = tl.transform.LayoutInference()(mod)
    mod = tvm.tir.transform.Simplify()(mod)
    ref_mod = tvm.tir.transform.BindTarget(auto_target)(After)
    ref_mod = tvm.tir.transform.Simplify()(ref_mod)
    # Note(tzj): The structures are equal except one more "for" loop after the LayoutInference pass
    # This loop is "for vec in T.parallel(1)",
    # Since the loop var "vec" is never used in the loop body, it does not affect the correctness
    tvm.ir.structural_equal(mod, ref_mod)
    # tvm.ir.assert_structural_equal(mod, ref_mod)
if __name__ == "__main__":
...
...@@ -32,7 +32,8 @@ def vectorize_access_legalize(M: int = 64, N: int = 64):
def assert_vectorize_access(M: int = 64, N: int = 64):
func, expected = vectorize_access_legalize(M, N)
mod = tvm.IRModule({func.attrs["global_symbol"]: func})
with tvm.target.Target("cuda"):
    transformed = tl.transform.LegalizeVectorizedLoop()(mod)
tvm.ir.assert_structural_equal(transformed["main"].body, expected.body)
...