"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "298eeffb2eeecf4b48c46e955b993dbf57f21142"
Commit c8fc0cbb authored by Lei Wang, committed by GitHub

[Backend][WebGPU] Support WebGPU WGSL code generation (#86)

* bump version to v0.1.0

* [Enhancement] Add custom develop command for editable installs and update .gitignore

* [Documentation] Update README to include system dependencies installation instructions

* [Build] Update setup.py to support library file copying for both release and develop modes

* [Build] Refactor library file copying logic in setup.py

* [Documentation] Remove unnecessary install section header in Installation.md

* [Build] Add tox configuration and local distribution script for multi-Python version support

* [Build] Improve git submodule update function with better error handling

* [Build] Update LLVM configuration path in ROCm installation script

* [Build] Add .tox/ to .gitignore for tox testing environment

* [Build] Add support for TVM prebuild path configuration in CMakeLists.txt

* [Cleanup] Remove unused TVM runtime error codes header

* [Cleanup] Fix TVM grid constant type reference in CUDA module

* [Cleanup] Remove unused customized_code function from IR module

* [Feature] Add TileLang thread synchronization and storage access analysis passes

* [Build] Reorder DLL search path directories for more flexible library loading

* [Refactor] Improve thread synchronization and library path handling

- Rename ThreadSync and TileLangThreadSync functions in C++ code
- Update Python docstring for ThreadSync with more detailed description
- Reorder library path detection in tilelang environment setup
- Minor comment and code cleanup in CUDA and warp specialization modules

* [Refactor] Improve thread synchronization code style and formatting

- Standardize pointer type spacing in storage_access.h and storage_access.cc
- Update whitespace and indentation in thread_storage_sync.cc
- Reorder include statements in thread_partial_sync.cc
- Minor code formatting improvements across thread synchronization files

* [Refactor] Fix global function registration for ThreadSync

- Correct global function registration to use ThreadSync instead of TileLangThreadSync
- Update TVM global registration to match recent refactoring efforts

* [Refactor] Simplify ThreadSync global function registration

- Remove unnecessary whitespace in global function registration
- Compact the TVM global registration line for ThreadSync

* [Feature] Add WebGPU code generation support in TileLang

- Implement WebGPU code generator (codegen_webgpu.cc and codegen_webgpu.h)
- Add WebGPU target support in lower.py and target.py
- Update CMakeLists.txt to include WebGPU codegen source files
- Introduce WebGPU-specific code generation for WGSL shader language

* [Refactor] Improve WebGPU code generation formatting and readability

- Enhance code formatting in codegen_webgpu.cc and codegen_webgpu.h
- Standardize pointer type spacing and indentation
- Improve line breaks and reduce line length for better readability
- Minor code style improvements in WebGPU code generation

* [Test] Add WebGPU matrix multiplication code generation test

- Implement test_webgpu_codegen.py for WebGPU matrix multiplication
- Add assert_gemm_codegen function to validate WebGPU code generation
- Include basic matrix multiplication kernel test case

* Update README with WebGPU codegen support announcement
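
The new backend can be exercised end to end from Python. A minimal sketch: the `tilelang.lower`/`get_source` calls mirror the test added in this PR, while the import path for its `matmul` kernel builder is hypothetical:

```python
import tilelang

# Hypothetical import; `matmul` is the kernel builder defined in the
# test_webgpu_codegen.py file added by this commit.
from test_webgpu_codegen import matmul

func = matmul(1024, 1024, 1024, 16, 16, 16)
# "webgpu" dispatches lowering to the new WGSL code generator.
rt_mod, _ = tilelang.lower(func, target="webgpu")
# The imported device module holds the generated WGSL shader text.
print(rt_mod.imported_modules[0].get_source())
```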
@@ -110,6 +110,8 @@ tilelang_file_glob(GLOB TILE_LANG_SRCS
    src/target/utils.cc
    src/target/codegen_cpp.cc
    src/target/rt_mod_cpp.cc
+   # webgpu has no system dependencies
+   src/target/codegen_webgpu.cc
)
# Include CUDA source files if CUDA is enabled
......
@@ -11,6 +11,7 @@ Tile Language (**tile-lang**) is a concise domain-specific language designed to
<img src=./images/MatmulExample.png />
## Latest News
+ - 02/15/2025 ✨: Added WebGPU codegen support, see [Pull Request #86](https://github.com/tile-ai/tilelang/pull/86)!
- 02/12/2025 ✨: Excited to announce the release of [v0.1.0](https://github.com/tile-ai/tilelang/releases/tag/v0.1.0)!
- 02/10/2025 🚀: Added debug tools for TileLang—`T.print` for printing variables/buffers ([docs](https://tilelang.tile-ai.cn/tutorials/debug_tools_for_tilelang.html)) and a memory layout plotter ([examples/plot_layout](./examples/plot_layout)).
- 01/20/2025 ✨: We are excited to announce that tile-lang, a DSL for high-performance AI workloads, is now open source and available to the public!
......
This diff is collapsed. (Presumably src/target/codegen_webgpu.cc; the header it implements follows.)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file codegen_webgpu.h
* \brief Generate WebGPU shaders in WGSL.
*
* This module generates WGSL shading language.
* See https://www.w3.org/TR/WGSL/ for the language reference.
*/
#ifndef TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
#define TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_

#include <tvm/target/codegen.h>

#include <string>

#include "target/source/codegen_c.h"

namespace tvm {
namespace codegen {

/*!
 * \brief WebGPU code generator.
 *
 * Note that WGSL has a different syntax from normal C.
 * We only leverage the C codegen for expression generation and
 * hand-write most of the language generation.
 */
class CodeGenTileLangWebGPU final : public CodeGenC {
 public:
  explicit CodeGenTileLangWebGPU(Target target);
  // overrides
  std::string Finish() final;
  using CodeGenC::AddFunction;
  runtime::FunctionInfo AddFunction(const PrimFunc &f,
                                    bool skip_readonly_decl); // NOLINT(*)
  void InitFuncState(const PrimFunc &f) final;
  void PrintStorageSync(const CallNode *op) final;    // NOLINT(*)
  void PrintType(DataType t, std::ostream &os) final; // NOLINT(*)
  void BindThreadIndex(const IterVar &iv) final;      // NOLINT(*)
  // assignment printing
  void PrintSSAAssign(const std::string &target, const std::string &src,
                      DataType type) final;
  // overload visitor
  void VisitExpr_(const BroadcastNode *op, std::ostream &os) final; // NOLINT(*)
  void VisitExpr_(const CallNode *op, std::ostream &os) final;      // NOLINT(*)
  void VisitExpr_(const BufferLoadNode *op,
                  std::ostream &os) final;                          // NOLINT(*)
  void VisitExpr_(const CastNode *op, std::ostream &os) final;      // NOLINT(*)
  void VisitExpr_(const SelectNode *op, std::ostream &os) override; // NOLINT(*)
  void VisitExpr_(const FloatImmNode *op, std::ostream &os) final;  // NOLINT(*)
  void VisitExpr_(const IntImmNode *op, std::ostream &os) final;    // NOLINT(*)
  // stmt printing
  void VisitStmt_(const LetStmtNode *op) final;
  void VisitStmt_(const BufferStoreNode *op) final;
  void VisitStmt_(const ForNode *op) final;
  void VisitStmt_(const AllocateNode *op) final;
  void VisitStmt_(const AssertStmtNode *op) final;
  void VisitStmt_(const AllocateConstNode *op) final;
  void VisitStmt_(const WhileNode *op) final;

 private:
  /*!
   * \brief Enforce value to be U32.
   */
  static PrimExpr EnforceU32(PrimExpr value);
  /*!
   * \brief Storage type of bool values.
   */
  DataType boolean_storage_type_{DataType::Int(8)};
  // whether fp16 is enabled
  bool enable_fp16_{false};
  /*! \brief The header stream for the function label and enable directives,
   * if any; goes before any other declaration. */
  std::ostringstream header_stream;
  Target target_;
};

} // namespace codegen
} // namespace tvm

#endif // TVM_TARGET_SOURCE_CODEGEN_WEBGPU_H_
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
#include <math.h>
#include <stdbool.h>
// Not Implemented
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
#pragma once
// Not Implemented
@@ -119,7 +119,7 @@ private:
    const DataType &access_type = buffer->dtype;
    // i // 2, i % 8 can also be vectorized as factor 16
-   int max_vector_size = 128 / access_type.bits();
+   int max_vector_size = vector_load_bits_max_ / access_type.bits();
    // so we should disable this GCD optimization
    max_vector_size = arith::ZeroAwareGCD(max_vector_size, extent_ptr->value);
@@ -159,7 +159,7 @@ private:
    }
  }
- static const int vector_load_bits_max_ = 128;
+ const int vector_load_bits_max_ = 128;
  const ForNode *inner_for_;
  Map<Var, Range> iter_map_;
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import tilelang
from tilelang import tvm as tvm
import tilelang.testing
import tilelang.language as T


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):

    @T.prim_func
    def main(
            A: T.Buffer((M, K), dtype),
            B: T.Buffer((K, N), dtype),
            C: T.Buffer((M, N), dtype),
    ):
        # Initialize Kernel Context
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=0):
                T.copy(A[by * block_M, ko * block_K], A_shared, coalesced_width=2)
                T.copy(B[ko * block_K, bx * block_N], B_shared, coalesced_width=2)
                for i, j, k in T.Parallel(block_M, block_N, block_K):
                    C_local[i, j] += A_shared[i, k] * B_shared[k, j]
            T.copy(C_local, C[by * block_M, bx * block_N], coalesced_width=2)

    return main


def assert_gemm_codegen(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        dtype="float16",
        accum_dtype="float",
):
    func = matmul(M, N, K, block_M, block_N, block_K, dtype=dtype, accum_dtype=accum_dtype)

    print(func)

    rt_mod, _ = tilelang.lower(func, target="webgpu")

    src_code = rt_mod.imported_modules[0].get_source()

    assert src_code is not None


def test_gemm_codegen():
    assert_gemm_codegen(1024, 1024, 1024, 16, 16, 16)


if __name__ == "__main__":
    tilelang.testing.main()
@@ -228,6 +228,8 @@ def lower(
        device_mod = tvm._ffi.get_global_func("target.build.tilelang_cpp")(device_mod, target)
    elif target.kind.name == "llvm":
        device_mod = tvm._ffi.get_global_func("target.build.llvm")(device_mod, target)
+   elif target.kind.name == "webgpu":
+       device_mod = tvm._ffi.get_global_func("target.build.tilelang_webgpu")(device_mod, target)
    else:
        raise ValueError("Target is not supported")
......
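
The `webgpu` branch above looks up the packed function that codegen_webgpu.cc registers on the C++ side. A minimal sketch (not part of the diff) for checking whether that hook is present in the current build; `allow_missing` is the standard `tvm._ffi.get_global_func` keyword:

```python
from tilelang import tvm as tvm

# Probe for the WebGPU build hook used by lower(); returns None when the
# WGSL codegen was not compiled into this build.
build_webgpu = tvm._ffi.get_global_func("target.build.tilelang_webgpu", allow_missing=True)
if build_webgpu is None:
    print("target.build.tilelang_webgpu is not registered in this build")
```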
@@ -11,6 +11,7 @@ AVALIABLE_TARGETS = {
    "auto",
    "cuda",
    "hip",
+   "webgpu",
    "c",  # the C source backend
    "llvm",
}
......
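
With "webgpu" added to AVALIABLE_TARGETS (spelling as in the source), target validation now accepts it before lowering. A small sketch of that check; the import path for AVALIABLE_TARGETS is an assumption:

```python
# Hypothetical import path for the set shown in the diff above.
from tilelang.utils.target import AVALIABLE_TARGETS


def pick_target(name: str) -> str:
    """Validate a user-provided target string against the supported set."""
    if name not in AVALIABLE_TARGETS:
        raise ValueError(f"unsupported target {name!r}; choose from {sorted(AVALIABLE_TARGETS)}")
    return name


print(pick_target("webgpu"))  # valid after this PR
```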