Commit c8fc0cbb authored by Lei Wang, committed by GitHub

[Backend][WebGPU] Support WebGPU WGSL code generation (#86)

* bump version into v0.1.0

* [Enhancement] Add custom develop command for editable installs and update .gitignore

* [Documentation] Update README to include system dependencies installation instructions

* [Build] Update setup.py to support library file copying for both release and develop modes

* [Build] Refactor library file copying logic in setup.py

* [Documentation] Remove unnecessary install section header in Installation.md

* [Build] Add tox configuration and local distribution script for multi-Python version support

* [Build] Improve git submodule update function with better error handling

* [Build] Update LLVM configuration path in ROCm installation script

* [Build] Add .tox/ to .gitignore for tox testing environment

* [Build] Add support for TVM prebuild path configuration in CMakeLists.txt

* [Cleanup] Remove unused TVM runtime error codes header

* [Cleanup] Fix TVM grid constant type reference in CUDA module

* [Cleanup] Remove unused customized_code function from IR module

* [Feature] Add TileLang thread synchronization and storage access analysis passes

* [Build] Reorder DLL search path directories for more flexible library loading

* [Refactor] Improve thread synchronization and library path handling

- Rename ThreadSync and TileLangThreadSync functions in C++ code
- Update Python docstring for ThreadSync with more detailed description
- Reorder library path detection in tilelang environment setup
- Minor comment and code cleanup in CUDA and warp specialization modules

* [Refactor] Improve thread synchronization code style and formatting

- Standardize pointer type spacing in storage_access.h and storage_access.cc
- Update whitespace and indentation in thread_storage_sync.cc
- Reorder include statements in thread_partial_sync.cc
- Minor code formatting improvements across thread synchronization files

* [Refactor] Fix global function registration for ThreadSync

- Correct global function registration to use ThreadSync instead of TileLangThreadSync
- Update TVM global registration to match recent refactoring efforts

* [Refactor] Simplify ThreadSync global function registration

- Remove unnecessary whitespace in global function registration
- Compact the TVM global registration line for ThreadSync

* [Feature] Add WebGPU code generation support in TileLang

- Implement WebGPU code generator (codegen_webgpu.cc and codegen_webgpu.h)
- Add WebGPU target support in lower.py and target.py
- Update CMakeLists.txt to include WebGPU codegen source files
- Introduce WebGPU-specific code generation for WGSL shader language

* [Refactor] Improve WebGPU code generation formatting and readability

- Enhance code formatting in codegen_webgpu.cc and codegen_webgpu.h
- Standardize pointer type spacing and indentation
- Improve line breaks and reduce line length for better readability
- Minor code style improvements in WebGPU code generation

* [Test] Add WebGPU matrix multiplication code generation test

- Implement test_webgpu_codegen.py for WebGPU matrix multiplication
- Add assert_gemm_codegen function to validate WebGPU code generation
- Include basic matrix multiplication kernel test case

* Update README with WebGPU codegen support announcement
parent ec84188f
@@ -110,6 +110,8 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS
src/target/utils.cc
src/target/codegen_cpp.cc
src/target/rt_mod_cpp.cc
# WebGPU codegen has no system dependencies
src/target/codegen_webgpu.cc
)
# Include CUDA source files if CUDA is enabled
......
@@ -11,6 +11,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
<img src=./images/MatmulExample.png />
## Latest News
- 02/15/2025 ✨: Added WebGPU codegen support, see [Pull Request #86](https://github.com/tile-ai/tilelang/pull/86)!
- 02/12/2025 ✨: Excited to announce the release of [v0.1.0](https://github.com/tile-ai/tilelang/releases/tag/v0.1.0)!
- 02/10/2025 🚀: Added debug tools for TileLang—`T.print` for printing variables/buffers ([docs](https://tilelang.tile-ai.cn/tutorials/debug_tools_for_tilelang.html)) and a memory layout plotter ([examples/plot_layout](./examples/plot_layout)).
- 01/20/2025 ✨: We are excited to announce that tile-lang, a DSL for high-performance AI workloads, is now open source and available to the public!
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_webgpu.cc
*/
#include "codegen_webgpu.h"
#include <tvm/arith/analyzer.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/transform.h>
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "arith/pattern_match.h"
#include "runtime/meta_data.h"
#include "runtime/thread_storage_scope.h"
#include "target/build_common.h"
namespace tvm {
namespace codegen {
// WebGPU Info
struct WebGPUWorkGroupInfo {
int workgroup_size[3] = {1, 1, 1};
// whether blockIdx.z is referenced.
bool has_block_index_z{false};
// set of handles that have write access
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> write_access_set;
};
class WebGPUWorkgroupInfoCollector : public StmtExprVisitor {
public:
static WebGPUWorkGroupInfo Collect(const Stmt &stmt) {
WebGPUWorkgroupInfoCollector collector;
collector(stmt);
return collector.info_;
}
private:
void VisitExpr_(const VarNode *op) final {
StmtExprVisitor::VisitExpr_(op);
Var buffer_var = GetRef<Var>(op);
if (buffer_var.dtype().is_handle()) {
info_.write_access_set.insert(buffer_var);
}
}
void VisitStmt_(const BufferStoreNode *op) final {
StmtExprVisitor::VisitStmt_(op);
info_.write_access_set.insert(op->buffer->data);
}
void VisitStmt_(const AttrStmtNode *op) final {
// record workgroup size
if (op->attr_key == tir::attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
if (iv->thread_tag.length() != 0) {
runtime::ThreadScope ts = runtime::ThreadScope::Create(iv->thread_tag);
if (ts.rank == 1) {
ICHECK_GE(ts.dim_index, 0)
<< "vthread should have been optimized out by here";
ICHECK_LT(ts.dim_index, 3);
auto *sizeptr = op->value.as<tir::IntImmNode>();
ICHECK(sizeptr) << "CodeGenTileLangWebGPU: only allows constant "
"thread group size, got "
<< op->value;
info_.workgroup_size[ts.dim_index] =
static_cast<uint32_t>(sizeptr->value);
} else if (ts.rank == 0) {
if (ts.dim_index == 2) {
info_.has_block_index_z = true;
}
}
}
}
// normal operation
StmtExprVisitor::VisitStmt_(op);
}
WebGPUWorkGroupInfo info_;
};
std::string CodeGenTileLangWebGPU::Finish() {
// Using f16 requires enable directive
if (enable_fp16_) {
header_stream << "enable f16;\n\n";
}
// WebGPU WGSL doesn't support #include.
// We must explicitly include all the templates here.
return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() +
stream.str();
}
void CodeGenTileLangWebGPU::InitFuncState(const PrimFunc &f) {
CodeGenC::InitFuncState(f);
// bind handle parameters to the global storage scope
for (Var arg : f->params) {
if (arg.dtype().is_handle()) {
alloc_storage_scope_[arg.get()] = "global";
}
}
}
CodeGenTileLangWebGPU::CodeGenTileLangWebGPU(Target target) : target_(target) {}
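// AddFunction emits one complete WGSL kernel. For a kernel with one
// read-only input, one read-write output, and a scalar i32 argument, the
// generated shader looks roughly like the sketch below (identifiers are
// illustrative, not the exact generated names):
//
//   @group(0) @binding(0) var<storage, read> A : array<f16>;
//   @group(0) @binding(1) var<storage, read_write> C : array<f16>;
//   struct PODArgs { n: i32, packGridDimX: u32 }
//   @group(0) @binding(2) var<uniform> podArgs : PODArgs;
//
//   @compute @workgroup_size(128, 1, 1)
//   fn main_kernel(
//     @builtin(workgroup_id) blockIdx : vec3<u32>,
//     @builtin(num_workgroups) gridDim : vec3<u32>,
//     @builtin(local_invocation_id) threadIdx : vec3<u32>
//   ) {
//     if (blockIdx.z * gridDim.x + blockIdx.x > podArgs.packGridDimX) { return; }
//     ...
//   }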
runtime::FunctionInfo
CodeGenTileLangWebGPU::AddFunction(const PrimFunc &f, bool skip_readonly_decl) {
// clear previous generated state.
this->InitFuncState(f);
// reserve keywords
name_supply_->ReserveName("var");
name_supply_->ReserveName("let");
name_supply_->ReserveName("const");
// consume the bare prefix so SSA variables start from a numbered name
name_supply_->FreshName("v_");
// Setup the thread group info.
ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
ICHECK_EQ(name_supply_->FreshName("gridDim"), "gridDim");
// add to alloc buffer type.
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined()) << "CodeGenTileLangWebGPU: Expect PrimFunc "
"to have the global_symbol attribute";
header_stream << "//----------------------------------------\n"
<< "// Function: " << global_symbol.value() << "\n"
<< "//----------------------------------------\n";
runtime::FunctionInfo func_info;
func_info.name = global_symbol.value();
WebGPUWorkGroupInfo info = WebGPUWorkgroupInfoCollector::Collect(f->body);
std::vector<Var> pod_args;
int num_buffer = 0;
// add param_access modes info to launch params
std::ostringstream os_param_access;
os_param_access << "paramWriteAccess:[";
// set up buffer arguments
for (Var arg : f->params) {
DataType t = arg.dtype();
func_info.arg_types.push_back(t);
if (t.is_handle()) {
auto *ptr = arg->type_annotation.as<PointerTypeNode>();
ICHECK(ptr) << "All handles passed to the CodeGenTileLangWebGPU must "
"have a type_annotation as a "
"PointerType, "
<< "and must point to a PrimType";
auto *prim = ptr->element_type.as<PrimTypeNode>();
ICHECK(prim) << "All handles passed to the CodeGenTileLangWebGPU must "
"have a type_annotation as a "
"PointerType, "
<< "and must point to a PrimType";
DataType value_storage_type = prim->dtype;
if (value_storage_type == DataType::Bool()) {
// We need a physically addressable buffer type to support boolean
// tensors. The loaded byte is cast to bool inside the LoadNode visitor
// below.
value_storage_type =
boolean_storage_type_.with_lanes(value_storage_type.lanes());
}
std::string vid = AllocVarID(arg.get());
std::string access_mode;
if (num_buffer != 0) {
os_param_access << ",";
}
if (skip_readonly_decl || info.write_access_set.count(arg)) {
access_mode = "read_write";
os_param_access << "1";
} else {
access_mode = "read";
os_param_access << "0";
}
// add extra access mode info to launch params
this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
<< "var<storage, " << access_mode << "> " << vid
<< " : array<";
this->PrintType(value_storage_type, this->decl_stream);
this->decl_stream << ">;\n";
} else {
pod_args.push_back(arg);
}
}
// Pack all POD arguments into a single uniform struct; values can be
// bitcast to other data types as needed. We always pass gridDimX in to
// work around the 65535 gridDim.x restriction on some platforms.
std::string type_pod_args = name_supply_->FreshName("PODArgs");
std::string val_pod_args = name_supply_->FreshName("podArgs");
std::string packGridDimX = name_supply_->FreshName("packGridDimX");
this->decl_stream << "\nstruct " << type_pod_args << " {\n";
for (size_t i = 0; i < pod_args.size(); ++i) {
Var v = pod_args[i];
ICHECK(!v.dtype().is_handle());
std::string vid = AllocVarID(v.get());
if (v.dtype() == DataType::Int(32)) {
this->decl_stream << " " << vid << ": i32";
} else if (v.dtype() == DataType::UInt(32)) {
this->decl_stream << " " << vid << ": u32";
} else if (v.dtype() == DataType::Float(32)) {
this->decl_stream << " " << vid << ": f32";
} else {
LOG(FATAL) << "Do not support pod argument type " << v.dtype();
}
this->decl_stream << ",\n";
// value ref
std::ostringstream vref;
vref << val_pod_args << "." << vid;
var_idmap_[v.get()] = vref.str();
}
this->decl_stream << " " << packGridDimX << ": u32\n}\n";
this->decl_stream << "@group(0) @binding(" << num_buffer++ << ") "
<< "var<uniform> " << val_pod_args << " : " << type_pod_args
<< ";\n\n";
// set up thread tags and parameter access modes in the launch param tags
if (auto opt = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams)) {
for (const auto &thread_tag : opt.value()) {
func_info.launch_param_tags.push_back(thread_tag);
}
}
os_param_access << "]";
func_info.launch_param_tags.push_back(os_param_access.str());
ICHECK(!info.has_block_index_z)
<< "blockIdx.z is not supported in WebGPU to accomodate large blockIdx.x";
// annotate workgroup size
this->stream << "@compute @workgroup_size(" << info.workgroup_size[0] << ", "
<< info.workgroup_size[1] << ", " << info.workgroup_size[2]
<< ")\n";
// Function header.
this->stream << "fn " << func_info.name << "(\n"
<< " @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
<< " @builtin(num_workgroups) gridDim : vec3<u32>,\n"
<< " @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
<< ") {\n";
// skip out of bound grids
this->stream << " if (blockIdx.z * gridDim.x + blockIdx.x > " // NOLINT(*)
<< val_pod_args << "." << packGridDimX << ") { return; }\n";
// the function scope.
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n\n";
return func_info;
}
void CodeGenTileLangWebGPU::BindThreadIndex(const IterVar &iv) {
ICHECK(!var_idmap_.count(iv->var.get()));
std::ostringstream os;
PrintType(iv->var.dtype(), os);
if (iv->thread_tag == "blockIdx.x") {
// WebGPU restricts the maximum value of blockIdx.x to 65535. We allow the
// runtime to spread the workload across blockIdx.z so the logical grid can
// be larger.
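// For example (illustrative), a logical grid of 100000 blocks may be
// launched as gridDim = (50000, 1, 2); the aggregated expression below then
// recovers the logical block index.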
os << "(blockIdx.z * gridDim.x + blockIdx.x)";
std::string tidx = os.str();
std::string aggregated_bidx = SSAGetID(os.str(), iv->var.dtype());
var_idmap_[iv->var.get()] = aggregated_bidx;
} else {
os << "(" << iv->thread_tag << ")";
std::string tidx = os.str();
this->MarkConst(tidx);
var_idmap_[iv->var.get()] = tidx;
}
}
void CodeGenTileLangWebGPU::PrintType(DataType t,
std::ostream &os) { // NOLINT(*)
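// Illustrative mapping: float32 -> f32, float16 -> f16 (sets enable_fp16_
// so Finish() emits "enable f16;"), int32 -> i32, uint32 -> u32, and a
// 4-lane float32 -> vec4<f32>.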
int lanes = t.lanes();
if (t.is_handle()) {
LOG(FATAL) << "Cannot print handle type in WebGPU";
}
if (t.is_void()) {
os << "void";
return;
}
if (t == DataType::Bool()) {
os << "bool";
return;
}
if (lanes != 1) {
// ICHECK(lanes >= 2 && lanes <= 4)
//     << "CodeGenTileLangWebGPU: only allows vector with lanes in {2, 3, 4},"
//     << " while lanes is " << lanes;
os << "vec" << lanes << "<";
}
if (t.is_float()) {
ICHECK(t.bits() == 16 || t.bits() == 32)
<< "CodeGenTileLangWebGPU: only support f16 or f32";
if (t.bits() == 16) {
// Using f16 requires enable directive
enable_fp16_ = true;
}
os << "f" << t.bits();
} else if (t.is_uint()) {
ICHECK(t.bits() != 64) << "CodeGenTileLangWebGPU: do not support u64";
os << "u" << t.bits();
} else if (t.is_int()) {
ICHECK(t.bits() != 64) << "CodeGenTileLangWebGPU: do not support i64";
os << "i" << t.bits();
} else {
LOG(FATAL) << "CodeGenTileLangWebGPU: Cannot convert type " << t
<< " to WebGPU type";
}
if (lanes != 1) {
os << ">";
}
}
void CodeGenTileLangWebGPU::PrintStorageSync(const CallNode *op) {
const std::string &sync = op->args[0].as<StringImmNode>()->value;
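// WGSL only exposes a workgroup-scope barrier, so both the "warp" and
// "shared" sync scopes conservatively lower to workgroupBarrier(); there is
// no device-wide barrier, hence "global" is rejected.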
if (sync == "warp") {
this->PrintIndent();
this->stream << "workgroupBarrier();\n";
} else if (sync == "shared") {
this->PrintIndent();
this->stream << "workgroupBarrier();\n";
} else if (sync == "global") {
LOG(FATAL) << "global barrier not supported";
}
}
void CodeGenTileLangWebGPU::PrintSSAAssign(const std::string &target,
const std::string &src,
DataType type) {
stream << "let " << target << " : ";
PrintType(type, stream);
stream << " = " << src << ";\n";
}
void CodeGenTileLangWebGPU::VisitExpr_(const BroadcastNode *op,
std::ostream &os) { // NOLINT(*)
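// Emit an explicit vector constructor repeating the value, e.g. a 4-lane
// f32 broadcast of v prints as vec4<f32>(v, v, v, v).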
std::string v = PrintExpr(op->value);
int lanes = op->dtype.lanes();
PrintType(op->dtype, os);
os << "(";
for (int i = 0; i < lanes; ++i) {
if (i != 0)
os << ", ";
os << v;
}
os << ')';
}
PrimExpr CodeGenTileLangWebGPU::EnforceU32(PrimExpr value) {
return cast(DataType::UInt(32, value.dtype().lanes()), value);
}
void CodeGenTileLangWebGPU::VisitExpr_(const CallNode *op,
std::ostream &os) { // NOLINT(*)
if (op->op.same_as(builtin::reinterpret())) {
// generate bitcast<TYPE>(ARG)
os << "bitcast<";
this->PrintType(op->dtype, os);
os << ">(";
this->PrintExpr(op->args[0], os);
os << ")";
} else if (op->op.same_as(builtin::shift_right())) {
os << '(';
this->PrintExpr(op->args[0], os);
os << ">>";
// WebGPU requires shift bits to be u32.
this->PrintExpr(EnforceU32(op->args[1]), os);
os << ')';
} else if (op->op.same_as(builtin::shift_left())) {
os << '(';
this->PrintExpr(op->args[0], os);
os << "<<";
// WebGPU requires shift bits to be u32.
this->PrintExpr(EnforceU32(op->args[1]), os);
os << ')';
} else if (op->op.same_as(builtin::if_then_else())) {
// conditional that skips eval if cond evals to false
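// Lowered to a mutable var plus an if/else, e.g. (names illustrative):
//   var condval_0 : f32;
//   if (cond) { condval_0 = true_val; } else { condval_0 = false_val; }
// WGSL's select() evaluates both operands, so it cannot be used here.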
std::string result = name_supply_->FreshName("condval");
std::string cond = PrintExpr(op->args[0]);
this->PrintIndent();
this->stream << "var " << result << " : ";
PrintType(op->dtype, this->stream);
this->stream << ";\n";
this->PrintIndent();
this->stream << "if (" << cond << ") {\n";
{
int then_scope = this->BeginScope();
std::string true_val = PrintExpr(op->args[1]);
this->PrintIndent();
this->stream << result << " = " << true_val << ";\n} else {\n";
this->EndScope(then_scope);
}
{
int else_scope = this->BeginScope();
std::string false_val = PrintExpr(op->args[2]);
this->PrintIndent();
this->stream << result << " = " << false_val << ";\n}\n";
this->EndScope(else_scope);
}
os << result;
} else {
CodeGenC::VisitExpr_(op, os);
}
}
void CodeGenTileLangWebGPU::VisitExpr_(const CastNode *op,
std::ostream &os) { // NOLINT(*)
PrintType(op->dtype, os);
os << "(" << PrintExpr(op->value) << ")";
}
void CodeGenTileLangWebGPU::VisitExpr_(const SelectNode *op,
std::ostream &os) { // NOLINT(*)
os << "select(" << PrintExpr(op->false_value) << ", "
<< PrintExpr(op->true_value) << ", " << PrintExpr(op->condition) << ")";
}
void CodeGenTileLangWebGPU::VisitExpr_(const IntImmNode *op,
std::ostream &os) { // NOLINT(*)
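// 32-bit literals take WGSL type suffixes, e.g. 42i and 42u; other widths
// fall back to an explicit conversion such as bool(1).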
if (op->dtype.bits() == 32) {
std::ostringstream temp;
if (op->dtype.is_int()) {
temp << op->value << "i";
} else {
ICHECK(op->dtype.is_uint());
temp << op->value << "u";
}
this->MarkConst(temp.str());
os << temp.str();
} else {
this->PrintType(op->dtype, os);
os << "(" << op->value << ")";
}
}
void CodeGenTileLangWebGPU::VisitExpr_(const FloatImmNode *op,
std::ostream &os) { // NOLINT(*)
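// Printed in scientific notation with a type suffix, e.g. 1.000000e+00f
// for f32 and 2.500000e-01h for f16.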
std::ostringstream temp;
temp << std::scientific << op->value;
if (op->dtype.bits() == 32) {
temp << 'f';
} else if (op->dtype.bits() == 16) {
// Using f16 requires enable directive
enable_fp16_ = true;
temp << 'h';
} else {
LOG(FATAL) << "Unsupported floating point bits " << op->dtype.bits();
}
MarkConst(temp.str());
os << temp.str();
}
void CodeGenTileLangWebGPU::VisitExpr_(const BufferLoadNode *op,
std::ostream &os) { // NOLINT(*)
// NOTE: direct implementation of load/store for correctness.
// Each printed statement must stand on its own after all preprocessing
// steps to stay correct with nested expressions; do not try to lift
// common printing out of each case.
ICHECK_EQ(op->indices.size(), 1)
<< "Load from non-flat memory not supported.";
DataType value_dtype = op->dtype;
PrimExpr index = op->indices[0];
Var buffer_var = op->buffer->data;
DataType element_dtype = op->buffer->dtype;
int lanes = op->dtype.lanes();
std::string buffer_vid = GetVarID(buffer_var.get());
if (value_dtype.lanes() == element_dtype.lanes()) {
// Direct buffer loading
// Special handle bool loading
if (value_dtype == DataType::Bool()) {
this->PrintType(value_dtype, os);
os << "(";
} else {
ICHECK(value_dtype == element_dtype);
}
ICHECK_EQ(index.dtype().lanes(), 1);
os << buffer_vid << "[" << this->PrintExpr(index) << "]";
// Special handle bool loading
if (value_dtype == DataType::Bool()) {
os << ")";
}
} else {
// Vector load from scalar buffer
ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array";
ICHECK(value_dtype.element_of() == element_dtype)
<< "WebGPU vector loading requires base type to match";
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, op->dtype.lanes()).Match(index)) {
// vec3<f32>(buf[base + 0], buf[base + 1], buf[base + 2]);
std::string base_vid =
SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype());
PrintType(element_dtype.with_lanes(value_dtype.lanes()), os);
os << "(";
for (int i = 0; i < lanes; ++i) {
if (i != 0)
os << ", ";
os << buffer_vid << "[" << base_vid << " + " << i << "]";
}
os << ")";
} else {
// vec3<f32>(buf[index[0]], buf[index[1]], buf[index[2]]);
std::string index_vid = SSAGetID(PrintExpr(index), index.dtype());
PrintType(element_dtype.with_lanes(value_dtype.lanes()), os);
os << "(";
for (int i = 0; i < lanes; ++i) {
if (i != 0)
os << ", ";
os << buffer_vid << "[" << index_vid << "[" << i << "]]";
}
os << ")";
}
}
}
void CodeGenTileLangWebGPU::VisitStmt_(const LetStmtNode *op) {
// use ssa form.
if (print_ssa_form_) {
std::string value = PrintExpr(op->value);
ICHECK(!var_idmap_.count(op->var.get()));
var_idmap_[op->var.get()] = value;
} else {
PrintIndent();
std::string value = PrintExpr(op->value);
this->stream << "let " << AllocVarID(op->var.get()) << " : ";
PrintType(op->var.dtype(), this->stream);
this->stream << " = " << value << ";\n";
}
PrintStmt(op->body);
}
void CodeGenTileLangWebGPU::VisitStmt_(const BufferStoreNode *op) {
CHECK_EQ(op->indices.size(), 1) << "Store to non-flat memory not supported.";
DataType value_dtype = op->value.dtype();
DataType element_dtype = op->buffer->dtype;
PrimExpr index = op->indices[0];
Var buffer_var = op->buffer->data;
std::string buffer_vid = GetVarID(buffer_var.get());
if (value_dtype.lanes() == element_dtype.lanes()) {
// must execute print expr first
// so we won't have recursive append to stream
std::string index_vid = PrintExpr(index);
std::string value_vid = PrintExpr(op->value);
// now print the assignment line.
this->PrintIndent();
stream << buffer_vid << "[" << index_vid << "] = ";
// special explicit conversion of bool
if (value_dtype == DataType::Bool()) {
PrintType(element_dtype, stream);
stream << "(";
} else {
ICHECK(value_dtype == element_dtype);
}
stream << value_vid;
// Special handle bool store
if (value_dtype == DataType::Bool()) {
stream << ")";
}
stream << ";\n";
} else {
// Vector store into scalar buffer
ICHECK_EQ(element_dtype.lanes(), 1) << "Can only vector load scalar array";
ICHECK(value_dtype.element_of() == element_dtype)
<< "WebGPU vector stire requires base type to match";
std::string value_vid = PrintExpr(op->value);
arith::PVar<PrimExpr> base;
if (arith::ramp(base, 1, value_dtype.lanes()).Match(index)) {
// buf[base + 0] = value[0]
// buf[base + 1] = value[1]
std::string base_vid =
SSAGetID(PrintExpr(base.Eval()), base.Eval().dtype());
for (int i = 0; i < value_dtype.lanes(); ++i) {
this->PrintIndent();
stream << buffer_vid << "[" << base_vid << " + " << i
<< "] = " << value_vid << "[" << i << "];\n";
}
} else {
// buf[index[0]] = value[0]
// buf[index[1]] = value[1]
std::string index_vid = SSAGetID(PrintExpr(index), index.dtype());
for (int i = 0; i < value_dtype.lanes(); ++i) {
this->PrintIndent();
stream << buffer_vid << "[" << index_vid << "[" << i
<< "]] = " << value_vid << "[" << i << "];\n";
}
}
}
}
void CodeGenTileLangWebGPU::VisitStmt_(const AllocateNode *op) {
ICHECK(!is_zero(op->condition));
std::string vid = AllocVarID(op->buffer_var.get());
size_t constant_size = op->ConstantAllocationSize();
ICHECK_GT(constant_size, 0)
<< "Can only handle constant size stack allocation for now";
auto storage_scope =
runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
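// Shared memory becomes a module-scope declaration, e.g. (illustrative)
//   var<workgroup> A_shared : array<f16, 256>;
// while local allocations are declared inside the function body, e.g.
//   var buf : array<f32, 8>;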
if (storage_scope.rank == runtime::StorageRank::kShared) {
this->decl_stream << "var<workgroup> " << vid << " : array<";
PrintType(op->dtype, this->decl_stream);
this->decl_stream << ", " << constant_size << ">;\n";
} else if (storage_scope.rank == runtime::StorageRank::kLocal) {
// TODO(Charlie): This code would cause non-uniformity, as it introduces
// variables in module scope rather than function scope; it was included
// for unknown reasons and is kept for now:
//   this->decl_stream << "var<private> " << vid << " : array<";
//   PrintType(op->dtype, this->decl_stream);
//   this->decl_stream << ", " << constant_size << ">;\n";
this->PrintIndent();
this->stream << "var " << vid << " : array<";
PrintType(op->dtype, this->stream);
this->stream << ", " << constant_size << ">;\n";
} else {
LOG(FATAL) << "WebGPU: Do not support storage scope: "
<< storage_scope.to_string();
}
this->PrintStmt(op->body);
}
void CodeGenTileLangWebGPU::VisitStmt_(const ForNode *op) {
std::string extent = PrintExpr(op->extent);
std::string vid = AllocVarID(op->loop_var.get());
ICHECK(is_zero(op->min));
PrintIndent();
stream << "for (var " << vid << " : ";
PrintType(op->loop_var.dtype(), stream);
stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
PrintIndent();
stream << "}\n";
}
void CodeGenTileLangWebGPU::VisitStmt_(const AssertStmtNode *op) {
// skip assert
PrintStmt(op->body);
}
void CodeGenTileLangWebGPU::VisitStmt_(const AllocateConstNode *op) {
LOG(FATAL) << "WebGPU: do not support alloc const";
}
void CodeGenTileLangWebGPU::VisitStmt_(const WhileNode *op) {
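// The condition may expand into helper statements, so print it inside the
// loop scope and lower the loop to:
//   while (true) { if (!(cond)) { break; } ... }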
PrintIndent();
stream << "while (true) {\n";
int while_scope = BeginScope();
std::string cond = PrintExpr(op->condition);
PrintIndent();
stream << "if (!(" << cond << ")) { break; }\n";
PrintStmt(op->body);
this->EndScope(while_scope);
PrintIndent();
stream << "}\n";
}
//-------------------------------------------------
// WebGPUSourceModule to enable export
//-------------------------------------------------
class WebGPUSourceModuleNode final : public runtime::ModuleNode {
public:
explicit WebGPUSourceModuleNode(
std::unordered_map<std::string, std::string> smap,
std::unordered_map<std::string, runtime::FunctionInfo> fmap)
: smap_(smap), fmap_(fmap) {}
const char *type_key() const final { return "webgpu"; }
/*! \brief Get the property of the runtime module. */
int GetPropertyMask() const final {
return runtime::ModulePropertyMask::kBinarySerializable;
}
PackedFunc GetFunction(const String &name,
const ObjectPtr<Object> &sptr_to_self) final {
LOG(FATAL) << "WebGPUSourceModule is not directly runnable, export and run "
"through tvmjs";
return PackedFunc(nullptr);
}
void SaveToBinary(dmlc::Stream *stream) final {
stream->Write(fmap_);
stream->Write(smap_);
}
String GetSource(const String &format) final {
if (format == "func_info") {
std::ostringstream stream;
dmlc::JSONWriter(&stream).Write(fmap_);
return stream.str();
} else {
std::ostringstream os;
for (auto kv : smap_) {
os << kv.second;
}
return os.str();
}
}
private:
// function shader code table.
std::unordered_map<std::string, std::string> smap_;
// function information table.
std::unordered_map<std::string, runtime::FunctionInfo> fmap_;
};
//-------------------------------------------------
// Build logic.
//-------------------------------------------------
runtime::Module BuildTileLangWebGPU(IRModule mod, Target target) {
mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
bool output_ssa = false;
bool skip_readonly_decl = false;
std::unordered_map<std::string, std::string> smap;
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
// WGSL has no 64-bit integer types, so narrow all i64 indices to i32
mod = tir::transform::ForceNarrowIndexToInt32()(std::move(mod));
for (auto kv : mod->functions) {
CodeGenTileLangWebGPU cg(target);
ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenTileLangWebGPU: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
<< "CodeGenTileLangWebGPU: expect calling_conv equals "
"CallingConv::kDeviceKernelLaunch";
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
ICHECK(global_symbol.defined()) << "CodeGenTileLangWebGPU: Expect PrimFunc "
"to have the global_symbol attribute";
std::string f_name = global_symbol.value();
cg.Init(output_ssa);
fmap[f_name] = cg.AddFunction(f, skip_readonly_decl);
std::string code = cg.Finish();
smap[f_name] = code;
}
auto n = make_object<WebGPUSourceModuleNode>(smap, fmap);
return runtime::Module(n);
}
TVM_REGISTER_GLOBAL("target.build.tilelang_webgpu")
.set_body_typed([](IRModule mod, Target target) {
return BuildTileLangWebGPU(mod, target);
});
} // namespace codegen
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_webgpu.h
* \brief Generate WebGPU shaders in WGSL.
*
* This module generates WGSL shading language.
* See https://www.w3.org/TR/WGSL/ for the language reference.
*/
#ifndef TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
#define TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
#include <tvm/target/codegen.h>
#include <string>
#include "target/source/codegen_c.h"
namespace tvm {
namespace codegen {
/*!
* \brief WebGPU code generator.
*
 * Note that WGSL has a different syntax from C.
 * We only leverage CodeGenC for expression generation and
 * write most of the language constructs ourselves.
*/
class CodeGenTileLangWebGPU final : public CodeGenC {
public:
explicit CodeGenTileLangWebGPU(Target target);
// overrides
std::string Finish() final;
using CodeGenC::AddFunction;
runtime::FunctionInfo AddFunction(const PrimFunc &f,
bool skip_readonly_decl); // NOLINT(*)
void InitFuncState(const PrimFunc &f) final;
void PrintStorageSync(const CallNode *op) final; // NOLINT(*)
void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
void BindThreadIndex(const IterVar &iv) final; // NOLINT(*)
// assignment printing
void PrintSSAAssign(const std::string &target, const std::string &src,
DataType type) final;
// overload visitor
void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const CallNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const BufferLoadNode *op,
std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const CastNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const SelectNode *op, std::ostream &os) override; // NOLINT(*)
void VisitExpr_(const FloatImmNode *op, std::ostream &os) final; // NOLINT(*)
void VisitExpr_(const IntImmNode *op, std::ostream &os) final; // NOLINT(*)
// stmt printing
void VisitStmt_(const LetStmtNode *op) final;
void VisitStmt_(const BufferStoreNode *op) final;
void VisitStmt_(const ForNode *op) final;
void VisitStmt_(const AllocateNode *op) final;
void VisitStmt_(const AssertStmtNode *op) final;
void VisitStmt_(const AllocateConstNode *op) final;
void VisitStmt_(const WhileNode *op) final;
private:
/*!
* \brief Enforce value to be U32.
*/
static PrimExpr EnforceU32(PrimExpr value);
/*!
* \brief Storage type of bool values.
*/
DataType boolean_storage_type_{DataType::Int(8)};
// whether enable fp16
bool enable_fp16_{false};
/*! \brief The header stream for the function label and enable directives,
 * if any; it goes before any other declaration. */
std::ostringstream header_stream;
Target target_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <math.h>
#include <stdbool.h>
// Not Implemented
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
// Not Implemented
@@ -119,7 +119,7 @@ private:
const DataType &access_type = buffer->dtype;
// i // 2, i % 8 can also be vectorized as factor 16
-int max_vector_size = 128 / access_type.bits();
+int max_vector_size = vector_load_bits_max_ / access_type.bits();
// so we should disable this GCD optimization
max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
@@ -159,7 +159,7 @@ private:
}
}
-static const int vector_load_bits_max_ = 128;
+const int vector_load_bits_max_ = 128;
const ForNode *inner_for_;
Map<Var, Range> iter_map_;
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tilelang
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
                T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)
                for i, j, k in T.Parallel(block_M, block_N, block_K):
                    C_local[i, j] += A_shared[i, k] * B_shared[k, j]
            T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2)

    return main


def assert_gemm_codegen(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    dtype="float16",
    accum_dtype="float",
):
    func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)
    print(func)
    rt_mod, _ = tilelang.lower(func, target="webgpu")
    src_code = rt_mod.imported_modules[0].get_source()
    assert src_code is not None


def test_gemm_codegen():
    assert_gemm_codegen(1024, 1024, 1024, 16, 16, 16)


if __name__ == "__main__":
    tilelang.testing.main()
@@ -228,6 +228,8 @@ def lower(
device_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
elif target.kind.name == "llvm":
device_mod = tvm._ffi.get_global_func("target.build.llvm")(device_mod, target)
elif target.kind.name == "webgpu":
device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
else:
raise ValueError("Target is not supported")
......
@@ -11,6 +11,7 @@ AVALIABLE_TARGETS = {
"auto",
"cuda",
"hip",
"webgpu",
"c", # represent c source backend
"llvm",
}
......