Commit c918cd4f authored by letaoqin

start

parent 126ce85a
set(TILE_EXAMPLE_FUSED_MOE "tile_example_fused_moe_general")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding ${TILE_EXAMPLE_FUSED_MOE}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_EXAMPLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp)
target_include_directories(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS})
set(TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template so the source compiles without explicitly declaring function specializations
list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to agpr
list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta
# list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
# list(APPEND TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(${TILE_EXAMPLE_FUSED_MOE} PRIVATE ${TILE_EXAMPLE_FUSED_MOE_COMPILE_OPTIONS})
# fused-moe
Implementing the fused-moe block operator using ck-tile. This is a scatter/gather, group-gemm based solution, similar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance.
![](misc/moe-0.png)
The benefits of this fused-moe:
* 1.5~2x performance boost compared with the current vLLM solution
* zero workspace, reducing the memory footprint
* far fewer kernel instances, easier to maintain
# Implementation and feature support
## moe-sorting
This is a common pre-processing step before the actual moe-gemm. The purpose is to transform the moe loop from token-by-token to expert-by-expert, making sure every workgroup works on a single expert (B matrix). Besides, we extend this op to zero the output buffer (which is later used as a reduction buffer with atomics).
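A minimal host-side sketch of this transform (plain C++; `moe_sorting_ref` and its signature are illustrative, not the actual ck-tile API, and the fused output-buffer zeroing is omitted). The toy example in the indexing section below follows this scheme:
```
#include <vector>

// illustrative reference: partition [tokens, topk] expert ids expert-by-expert,
// padding each expert's slice (even an empty one) to a multiple of block_m
void moe_sorting_ref(const std::vector<int>& topk_ids,      // [tokens * topk]
                     const std::vector<float>& topk_weight, // [tokens * topk]
                     int tokens, int topk, int num_experts, int block_m,
                     std::vector<int>& sorted_token_ids,
                     std::vector<float>& sorted_weight,
                     std::vector<int>& sorted_expert_ids)
{
    const int pad_id = tokens + 1; // any out-of-range token id works as padding (the example below pads with 6)
    for(int e = 0; e < num_experts; e++)
    {
        int cnt = 0;
        for(int t = 0; t < tokens; t++)
            for(int k = 0; k < topk; k++)
                if(topk_ids[t * topk + k] == e)
                {
                    sorted_token_ids.push_back(t); // the actual token id, not the t*topk+k expanded id
                    sorted_weight.push_back(topk_weight[t * topk + k]);
                    cnt++;
                }
        while(cnt == 0 || cnt % block_m != 0) // pad the slice up to a block_m multiple
        {
            sorted_token_ids.push_back(pad_id);
            sorted_weight.push_back(0.f);
            cnt++;
        }
        for(int b = 0; b < cnt / block_m; b++) // one expert id per block_m tile
            sorted_expert_ids.push_back(e);
    }
    // num_tokens_post_padded == sorted_token_ids.size()
}
```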
## moe-gemm
`moe-gemm` is a group-gemm based back-to-back gemm, where the row id of each input token comes from another buffer. A naive understanding of fused-moe takes the token-by-token view, as in the picture below:
![](misc/moe-1.png)
After `moe-sorting`, we can view this algorithm expert-by-expert, as below:
![](misc/moe-2.png)
## optimization
Summary of the key design points of this fused-moe operator:
* fuse the 2 group-gemms + activation + `topk-weight` multiplication into a single kernel, using atomics for the 2nd gemm accumulation
* fuse buffer zeroing into `moe-sorting`, so users no longer need to zero the output buffer with an extra torch call (e.g. `torch.zeros()`)
* fused scatter/gather for row indices (same as vLLM)
* pre-shuffle the B matrix (weight) to maximize memory throughput; the input (activation) keeps its original `[batch, hidden]` layout
* extremely optimized pipeline using block inline-asm (we call it a `micro-kernel` or `uk`), without breaking the *composable* design of ck
## indexing
```
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice maps to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                             tok-0      tok-1      tok-2      tok-3      tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float numbers)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
// (only for reference)     exp-0  exp-1     exp-2   exp-3          exp-4  exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) (here 3*5 + 6*3 = 33)
// * this could be larger than the actual size, since the per-expert counts are only known on the GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
//                        |- exp-0 -|- exp-1 -|- exp-2 -|-      exp-3      -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr    : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) the token_id stored in sorted_token_ids_ptr is the actual token_id, not the token_id*top_K expanded id
// 2) need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, which is already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include <string>
// this is only a convenience structure for creating the example;
// it is not part of the host API
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig;
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::bf16_t;
using GDataType = ck_tile::bf16_t;
using DDataType = ck_tile::bf16_t;
using AccDataType = float;
using ODataType = ck_tile::bf16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::fp16_t;
using GDataType = ck_tile::fp16_t;
using DDataType = ck_tile::fp16_t;
using AccDataType = float;
using ODataType = ck_tile::fp16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::int8_t;
using GDataType = ck_tile::int8_t;
using DDataType = ck_tile::int8_t;
using AccDataType = int32_t;
using ODataType = ck_tile::bf16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
// runtime args
struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs
{
};
// This is the public API; it will be generated by a script
struct fused_moegemm_traits
{
std::string prec_i; // input precision
std::string prec_w; // weight precision
std::string prec_o; // output precision
std::string prec_st; // token scale data type
std::string prec_sw; // weight scale data type
std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type
int block_m;
int gate_only;
int fused_quant; // 0:no-quant, 1:smooth-dynamic-quant, 2:dynamic-quant
};
float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&);
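// usage sketch (illustrative, not part of the API; device buffers are packed into
// fused_moegemm_args as in this example's main.cpp):
//   fused_moegemm_traits t{"bf16", "bf16", "bf16", "fp32", "fp32", "fp32", "fp32",
//                          /*block_m=*/32, /*gate_only=*/1, /*fused_quant=*/0};
//   float ave_ms = fused_moegemm(t, args, ck_tile::stream_config{nullptr, true});
//   // a negative return value means no instance matches the requested traits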
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
// Note: this internal API is only declared, not defined, here; otherwise it would block `make -j`
template <typename Traits_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a);
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile::stream_config& s)
{
// clang-format off
float r = -1;
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
r = fused_moegemm_<t_>(s, a);
}
// clang-format on
return r;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include <iostream>
template <ck_tile::index_t... Is>
using S = ck_tile::sequence<Is...>;
// do not put the definition of this template function inside the _api.cpp, otherwise it will block make -j
template <typename Ts_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
{
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0,
typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0>;
using f_problem =
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
typename Ts_::GDataType,
typename Ts_::DDataType,
typename Ts_::AccDataType,
typename Ts_::ODataType,
typename Ts_::AScaleDataType,
typename Ts_::GScaleDataType,
typename Ts_::DScaleDataType,
typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType,
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded
f_shape,
f_traits>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
const dim3 grids = f_kernel::GridSize(a);
constexpr dim3 blocks = f_kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
static int printed = 0;
auto kargs = f_kernel::MakeKargs(a);
if(s.log_level_ > 0 && printed == 0)
{
std::cout << ", " << f_kernel::GetName() << std::flush;
printed = 1;
}
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(f_kernel{}, grids, blocks, 0, kargs));
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <typename I,
typename W,
typename O,
typename ST,
typename SW,
typename SQ,
typename KW,
typename BlockTile_, // seq<b_token, b_interm, b_hidden, b_down>
typename WarpPerBlock_,
typename WarpTile_, // seq<*,*,*>, used to select mfma
ck_tile::index_t GateOnly_ = 0,
ck_tile::index_t FusedQuant_ = 0>
struct fmoe_ // traits; ugly name, only used internally
{
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = ck_tile::remove_cvref_t<typename TypeConfig::ADataType>;
using GDataType = ck_tile::remove_cvref_t<typename TypeConfig::GDataType>;
using DDataType = ck_tile::remove_cvref_t<typename TypeConfig::DDataType>;
using AccDataType = ck_tile::remove_cvref_t<typename TypeConfig::AccDataType>;
using ODataType = ck_tile::remove_cvref_t<typename TypeConfig::ODataType>;
using AScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::AScaleDataType>;
using GScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::GScaleDataType>;
using DScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::DScaleDataType>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<typename TypeConfig::YSmoothScaleDataType>;
using TopkWeightDataType = ck_tile::remove_cvref_t<typename TypeConfig::TopkWeightDataType>;
using IndexDataType = ck_tile::remove_cvref_t<typename TypeConfig::IndexDataType>;
static constexpr ck_tile::index_t BT_ = BlockTile_::at(ck_tile::number<0>{}); // block token
static constexpr ck_tile::index_t BI_ =
BlockTile_::at(ck_tile::number<1>{}); // block intermediate
static constexpr ck_tile::index_t BH_ = BlockTile_::at(ck_tile::number<2>{}); // block hidden
static constexpr ck_tile::index_t BD_ = BlockTile_::at(ck_tile::number<3>{}); // block down
using BlockTile_0 = ck_tile::sequence<BT_, BI_, BH_>;
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
static constexpr ck_tile::index_t GateOnly = GateOnly_;
static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "fused_moegemm.hpp"
#include "fused_moegemm_api_traits.hpp"
#include "fused_moegemm_api_internal.hpp"
// clang-format off
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on
#include <algorithm>
#include <cstring>
#include <unordered_set>
#include <vector>
#include <set>
#include "ck_tile/host.hpp"
#include "fused_moegemm.hpp"
// different thresholds for different dtypes
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
// mfma_type, 0:32x32, 1:16x16
// TODO: padding?
template <typename T>
auto shuffle_moe_weight(const ck_tile::HostTensor<T>& t, std::string mfma_dtype, int mfma_type = 0)
{
assert(t.get_lengths().size() == 3);
int b_ = t.get_lengths()[0];
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[2];
if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 16, 2, 8});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 32, 4, 8});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 0)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 32, 32, k_ / 32, 2, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
else if((mfma_dtype == "int8" || mfma_dtype == "fp8") && mfma_type == 1)
{
ck_tile::HostTensor<T> t_view({b_, n_ / 16, 16, k_ / 64, 4, 16});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 1, 3, 4, 2, 5});
}
return t;
}
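// usage sketch (hypothetical sizes): for the 16x16 mfma path used below, an [e, n, k]
// bf16 weight is viewed as [e, n/16, 16, k/32, 4, 8] and permuted to
// [e, n/16, k/32, 4, 16, 8], matching the pre-shuffled layout the kernel expects:
//   ck_tile::HostTensor<ck_tile::bf16_t> w({8, 128, 256}); // e=8, n=128, k=256
//   auto w_shuffled = shuffle_moe_weight(w, "bf16", /*mfma_type=*/1);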
template <typename IndexType>
void topid_unique_gen(
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
size_t total_size = topk * tokens;
std::srand(seed);
std::set<IndexType> unique_set;
IndexType current_v;
for(size_t i = 0; i < total_size; i++)
{
if(i % topk == 0)
{
unique_set.clear();
}
current_v = std::rand() % num_expert;
while(unique_set.find(current_v) != unique_set.end())
{
current_v = std::rand() % num_expert;
}
unique_set.insert(current_v);
host_tensor[i] = current_v;
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("t", "128", "num input tokens")
.insert("e", "32", "num of experts")
.insert("k", "5", "topk")
.insert("h", "8192", "hidden_size of this model")
.insert("i", "8192", "intermediate_size between 2 gemms of FFN")
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
.insert("bm", "32", "blocking factor for sorted tokens")
.insert("tp", "1", "tensor parallel size")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec_i", "bf16", "input precision")
.insert("prec_w", "bf16", "weight precision")
.insert("prec_o", "bf16", "output precision")
.insert("prec_st", "auto", "token scale data type. auto will set to fp32")
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert(
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
.insert("balance",
"0",
"if set to 1, will try balance the expert in topk-ids(convenient for testing)")
.insert("init",
"2",
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized"
"normalized(slow)")
.insert("seed", "11939", "seed used to do random")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
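// example invocation (assuming ck_tile::ArgParser's -name=value flag form; the binary
// name comes from the CMake target above):
//   ./tile_example_fused_moe_general -t=128 -e=32 -k=5 -h=8192 -i=8192 -v=1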
// I:input-type, W:weight-type, O:output-type, ST:token-scale-type, SW:weight-scale-type,
// SQ:smooth-quant-type, KW:topk-weight-type
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t tokens = arg_parser.get_int("t");
ck_tile::index_t experts = arg_parser.get_int("e");
ck_tile::index_t topk = arg_parser.get_int("k");
ck_tile::index_t hidden_size = arg_parser.get_int("h");
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
ck_tile::index_t stride = arg_parser.get_int("stride");
ck_tile::index_t block_m = arg_parser.get_int("bm");
if(stride < 0)
stride = hidden_size;
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq");
std::string prec_kw = arg_parser.get_str("prec_kw");
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int fused_quant = arg_parser.get_int("fquant");
int gate_only = arg_parser.get_int("gate_only");
int balance = arg_parser.get_int("balance");
int tp = arg_parser.get_int("tp");
int init = arg_parser.get_int("init");
uint32_t seed = arg_parser.get_uint32("seed");
// w0 (Gate+Up or Gate only, N size)
ck_tile::index_t shared_intermediate_size_0 = intermediate_size * (gate_only ? 1 : 2) / tp;
// w1 (Down, N size)
ck_tile::index_t shared_intermediate_size_1 = intermediate_size / tp;
auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_w)
base_str += "x" + prec_w;
if(prec_i != prec_o)
base_str += "=" + prec_o;
if(fused_quant != 0)
{
base_str += std::string("(") + prec_st + "|" + prec_sw + "|" + prec_sq + ")";
}
return base_str;
}();
auto api_str = [&]() {
return std::string("moeg");
}();
auto stride_str = [&]() {
if(stride == hidden_size)
return std::string("");
else
return std::string(", st:") + std::to_string(stride);
}();
std::cout << "[" << api_str << "|" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType;
using GDataType = typename TypeConfig::GDataType;
using DDataType = typename TypeConfig::DDataType;
using AccDataType = typename TypeConfig::AccDataType;
using ODataType = typename TypeConfig::ODataType;
using AScaleDataType = typename TypeConfig::AScaleDataType;
using GScaleDataType = typename TypeConfig::GScaleDataType;
using DScaleDataType = typename TypeConfig::DScaleDataType;
using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
using TopkWeightDataType = typename TypeConfig::TopkWeightDataType;
using IndexDataType = typename TypeConfig::IndexDataType;
// host verify
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
ck_tile::HostTensor<DScaleDataType> sd_host({shared_intermediate_size_1});
ck_tile::HostTensor<YSmoothScaleDataType> sy_host({shared_intermediate_size_1}); // smooth-quant
ck_tile::HostTensor<IndexDataType> topk_ids_host({tokens, topk}); // to be sorted
ck_tile::HostTensor<TopkWeightDataType> topk_weight_host({tokens, topk}); // to be sorted
int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
ck_tile::HostTensor<IndexDataType> sorted_token_ids_host({max_num_tokens_padded});
ck_tile::HostTensor<TopkWeightDataType> sorted_weight_host({max_num_tokens_padded});
ck_tile::HostTensor<IndexDataType> sorted_expert_ids_host(
{(max_num_tokens_padded + block_m - 1) / block_m});
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
if(init == 0)
{
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
ck_tile::FillStepRange<AScaleDataType>{0.f, 1.f, 0.01f}(sa_host);
ck_tile::FillStepRange<GScaleDataType>{0.f, 1.f, 0.01f}(sg_host);
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
}
else if(init == 1)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host);
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f, seed, true}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f, seed, true}(
topk_weight_host);
}
else if(init == 2)
{
ck_tile::FillNormalDistribution<ADataType>{0.f, 1.f, seed, true}(a_host);
ck_tile::FillNormalDistribution<GDataType>{0.f, 1.f, seed, true}(g_host);
ck_tile::FillNormalDistribution<DDataType>{0.f, 1.f, seed, true}(d_host);
ck_tile::FillNormalDistribution<AScaleDataType>{0.f, 1.f, seed, true}(sa_host);
ck_tile::FillNormalDistribution<GScaleDataType>{0.f, 1.f, seed, true}(sg_host);
ck_tile::FillNormalDistribution<DScaleDataType>{0.f, 1.f, seed, true}(sd_host);
ck_tile::FillNormalDistribution<YSmoothScaleDataType>{0.f, 1.f, seed, true}(sy_host);
ck_tile::FillNormalDistribution<TopkWeightDataType>{0.f, 1.f, seed, true}(topk_weight_host);
}
// permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
// do moe sorting
if(balance)
{
int e_cnt = 0;
for(int i = 0; i < static_cast<int>(topk_ids_host.mData.size()); i++)
{
topk_ids_host.mData[i] = e_cnt;
e_cnt++;
if(e_cnt >= experts)
e_cnt = 0;
}
}
else
{
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
}
// left here for future debugging purposes
#if 0
a_host.loadtxt("../../ater/input_torch.txt");
topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int");
// topk_ids_host.savetxt("topk_ids_2.txt");
topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
g_host.loadtxt("../../ater/w1_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
d_host.loadtxt("../../ater/w2_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
#endif
#if 0
std::cout << "sorted_token_ids_host:" << sorted_token_ids_host << std::endl;
std::cout << "num_sorted_tiles_host:" << num_sorted_tiles_host << std::endl;
std::cout << "sorted_expert_ids_host:" << sorted_expert_ids_host << std::endl;
std::cout << "topk_weight_host:" << topk_weight_host << std::endl;
std::cout << "sorted_weight_host:" << sorted_weight_host << std::endl;
#endif
auto cal_tflops = [&](auto ms) {
double flop_gemm_0 =
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_0 * hidden_size;
double flop_gemm_1 =
2 * static_cast<double>(tokens) * topk * shared_intermediate_size_1 * hidden_size;
return (flop_gemm_0 + flop_gemm_1) / (static_cast<double>(ms) * 1e-3) / 1e12;
};
// TODO: this method uses the expert-by-expert view, just for reference
auto cal_tbps = [&](auto ms) {
double token_bytes =
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ADataType);
double w0_bytes = static_cast<double>(shared_intermediate_size_0) * experts * hidden_size *
sizeof(GDataType);
double w1_bytes = static_cast<double>(shared_intermediate_size_1) * experts * hidden_size *
sizeof(DDataType);
double o_bytes =
static_cast<double>(tokens) * topk / experts * hidden_size * sizeof(ODataType);
double topk_weights_bytes = static_cast<double>(tokens) * topk * sizeof(TopkWeightDataType);
// ignore the index buffers, they are too small
return (token_bytes + w0_bytes + w1_bytes + o_bytes + topk_weights_bytes) /
(static_cast<double>(ms) * 1e-3) / 1e12;
};
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host,
topk_weight_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host.mData[0],
experts,
block_m);
// done; prepare the GPU buffers
ck_tile::DeviceMem a_buf(a_host);
ck_tile::DeviceMem g_perm_buf(g_perm_host);
ck_tile::DeviceMem d_perm_buf(d_perm_host);
ck_tile::DeviceMem sa_buf(sa_host);
ck_tile::DeviceMem sg_buf(sg_host);
ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem o_buf(o_host);
// manually clear output buffer for atomic
o_buf.SetZero();
//
ck_tile::DeviceMem sorted_token_ids_buf(sorted_token_ids_host);
ck_tile::DeviceMem sorted_weight_buf(sorted_weight_host);
ck_tile::DeviceMem sorted_expert_ids_buf(sorted_expert_ids_host);
ck_tile::DeviceMem num_sorted_tiles_buf(num_sorted_tiles_host);
fused_moegemm_traits traits{prec_i,
prec_w,
prec_o,
prec_st,
prec_sw,
prec_sq,
prec_kw,
block_m,
gate_only,
fused_quant};
fused_moegemm_args args{a_buf.GetDeviceBuffer(),
fused_quant != 0 ? sa_buf.GetDeviceBuffer() : nullptr,
g_perm_buf.GetDeviceBuffer(),
d_perm_buf.GetDeviceBuffer(),
fused_quant != 0 ? sg_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? sd_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? sy_buf.GetDeviceBuffer() : nullptr,
o_buf.GetDeviceBuffer(),
sorted_token_ids_buf.GetDeviceBuffer(),
sorted_weight_buf.GetDeviceBuffer(),
sorted_expert_ids_buf.GetDeviceBuffer(),
num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size,
shared_intermediate_size_0,
tokens,
experts,
topk,
stride};
float ave_time = fused_moegemm(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
// float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << cal_tflops(ave_time) << " tflops, "
<< cal_tbps(ave_time) << " TB/s" << std::flush;
bool pass = true;
if(do_validation)
{
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
a_host,
g_host,
d_host,
sa_host,
sg_host,
sd_host,
sy_host,
o_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host,
topk_ids_host,
block_m,
tokens,
experts,
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float");
auto [rtol, atol] = get_elimit<ADataType>();
pass &= ck_tile::check_err(
o_dev, o_host, std::string("OUT Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
std::cout << std::flush << std::endl;
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq");
std::string prec_kw = arg_parser.get_str("prec_kw");
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
// no dynamic quant case
if(prec_i == "bf16" && prec_w == "bf16" && prec_o == "bf16" && prec_kw == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float>(
arg_parser)
? 0
: -2;
}
else if(prec_i == "fp16" && prec_w == "fp16" && prec_o == "fp16" && prec_kw == "fp32")
{
return run<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float>(
arg_parser)
? 0
: -2;
}
return -3;
}
@@ -16,3 +16,4 @@ add_subdirectory(13_moe_sorting)
add_subdirectory(14_moe_smoothquant)
add_subdirectory(15_fused_moe)
add_subdirectory(16_batched_gemm)
add_subdirectory(17_fused_moe_general)
@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
@@ -11,6 +12,8 @@
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_uk.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"
#include <string>
#include <type_traits>
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice maps to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                             tok-0      tok-1      tok-2      tok-3      tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float numbers)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
// (only for reference)     exp-0  exp-1     exp-2   exp-3          exp-4  exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1) (here 3*5 + 6*3 = 33)
// * this could be larger than the actual size, since the per-expert counts are only known on the GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
//                        |- exp-0 -|- exp-1 -|- exp-2 -|-      exp-3      -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr    : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) the token_id stored in sorted_token_ids_ptr is the actual token_id, not the token_id*top_K expanded id
// 2) need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, which is already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4) num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate the original row/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of the above, then:
// topk_row_id(token_id) = x % num_tokens(5)
// topk_col_id(expert_id) = x / num_tokens
// topk_row_id/col_id can be used to access the original topk_ids/topk_weight
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
// (only for reference)     exp-0  exp-1     exp-2   exp-3          exp-4  exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
//
namespace ck_tile {
// m: num_tokens (or token*input-batch)
// k: hidden_size
// n: intermediate_size used between the 2 FC layers (TP slices this)
// e: num_experts
// if doing pre-shuffle
// nr : n / Block_Nr
// kr : k / Block_Kr
// w : flattened 1d wave buffer
// struct FusedMoeGemmHostArgs
// {
// const void* a_ptr; // [m, k], input token
// const void* a_scale_ptr; // [m, 1], token scale
// const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
// const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
// const void* g_scale_ptr; // [e, 1, n], gate(up) scale
// const void* d_scale_ptr; // [e, 1, k], down scale
// const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
// void* o_ptr; // [m, k], output token
// const void* sorted_token_ids_ptr; // [max_num_tokens_padded]
// const void* sorted_weight_ptr; // [max_num_tokens_padded]
// const void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
// const void* num_sorted_tiles_ptr; // [1]
// index_t hidden_size; // k
// index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
// index_t num_tokens; // input number of tokens for current iteration
// index_t num_experts; // number of groups
// index_t topk; // need this?
// index_t stride_token; // for input/output, stride for each row, should >= hidden_size
// };
// This is a scatter/gather back-to-back group-gemm
template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmGlKernel
{
using Partitioner = remove_cvref_t<Partitioner_>;
using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static_assert(kBlockPerCu > 0);
using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
static constexpr index_t BlockSize_ = BlockShape::BlockSize;
using ADataType = typename Pipeline::Problem::ADataType;
using GDataType = typename Pipeline::Problem::GDataType;
using DDataType = typename Pipeline::Problem::DDataType;
using AccDataType = typename Pipeline::Problem::AccDataType;
using ODataType = typename Pipeline::Problem::ODataType;
using AScaleDataType = typename Pipeline::Problem::AScaleDataType;
using GScaleDataType = typename Pipeline::Problem::GScaleDataType;
using DScaleDataType = typename Pipeline::Problem::DScaleDataType;
using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Pipeline::Problem::TopkWeightDataType;
using IndexDataType = typename Pipeline::Problem::IndexDataType;
using YDataType = typename Pipeline::Problem::YDataType;
using Traits = typename Pipeline::Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using S_ = BlockShape;
auto prec_str = [&] () {
std::string base_str = _SS_(t2s<ADataType>::name);
if (!std::is_same_v<ADataType, GDataType>) {
base_str += _SS_("_") + _SS_(t2s<GDataType>::name);
}
return base_str;
}();
return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
_TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
_TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
#undef _SS_
#undef _TS_
// clang-format on
}
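// e.g. with the shapes used by this example's instances (S<32, 512, 128, 128>,
// S<1, 4, 1>, S<16, 16, 32>), this would yield a name like
// "fused_moe_bf16_32x512x128x128_1x4x1_16x16x32_" + Pipeline::name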
struct FusedMoeGemmKargs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
index_t topk; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
};
// TODO: switch karg based on
using Kargs = FusedMoeGemmKargs;
using Hargs = FusedMoeGemmHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
// TODO: hargs/kargs are not guaranteed to be the same
return bit_cast<Kargs>(hargs);
}
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
{
constexpr index_t block_m = BlockShape::Block_M0;
int max_num_tokens_padded =
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
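// upper bound sketch: each token's topk expert ids are unique, so at least topk
// experts are non-empty (each padded by at most block_m - 1), and up to
// (num_experts - topk) empty experts may each get one fully-padded tile, giving
// topk*num_tokens + topk*(block_m - 1) + (num_experts - topk)*block_m
// = topk*num_tokens + num_experts*block_m - topk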
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
constexpr index_t hidden_ratio_0 = IsGateOnly ? 1 : 2;
index_t expert_stride_0 = kargs.intermediate_size * hidden_ratio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note this is in units of tiles; multiply by the tile size (block_m and block_n)
// to get the index
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles)
return;
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size
// index_t hidden_idx = __builtin_amdgcn_readfirstlane(intermediate_tile_id *
// BlockShape::Block_N0);
index_t idx_m0 = __builtin_amdgcn_readfirstlane(sorted_tile_id * BlockShape::Block_M0);
index_t idx_n0 = __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_N0);
// const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
// if(threadIdx.x == 200 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0){
// printf("\n*************a_coord[0]: %d, a_coord[1]: %d size: %d \n",
// a_coord[number<0>{}], a_coord[number<1>{}], a_coord.size());
// }
// const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id *
// BlockShape::Block_M0; //not block pos?
const auto sorted_token_id = sorted_tile_id * BlockShape::Block_M0; // start position of this block_m tile
// index_t token_id =
// reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto topk_weight =
reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
const index_t* sorted_token_ids_ptr =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr);
const auto a_window = [&]() {
// A is already pre-padded in the previous kernel
const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentA>{},
number<1>{});
// the gather happens here, using an indexing transform
const auto a_gather_view_ = transform_tensor_view(
a_view_,
make_tuple(make_indexing_transform_with_adaptor(
kargs.max_num_tokens_padded,
indexing_adaptor<index_t>{sorted_token_ids_ptr}),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto a_window_ = make_tile_window(
a_gather_view_,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{idx_m0, 0});
return a_window_;
}();
const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(kargs.intermediate_size, kargs.hidden_size),
make_tuple(kargs.hidden_size, 1),
number<Pipeline::kAlignmentG>{},
number<1>{});
const auto g_window_ = make_tile_window(
g_view_,
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
{idx_n0, 0});
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// for(int i = 0; i < 16; i++)
// {
// printf("in G index is %d , value is: %f\n",
// i,
// ck_tile::type_convert<float>(g_ptr[i]));
// }
// }
return g_window_;
}();
const auto d_window = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1;
// note interm_idx_nr is along the gemm-k dim of 2nd gemm
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(kargs.hidden_size, kargs.intermediate_size),
make_tuple(kargs.intermediate_size, 1),
number<Pipeline::kAlignmentD>{},
number<1>{});
const auto d_window_ = make_tile_window(
d_view_,
make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
{0, idx_n0});
return d_window_;
}();
auto o_window = [&]() {
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
auto o_view_ = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
o_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentO>{},
number<1>{});
// the scatter happens here (output rows are scattered by token id)
auto o_scatter_view_ = transform_tensor_view(
o_view_,
make_tuple(make_indexing_transform_with_adaptor(
kargs.max_num_tokens_padded,
indexing_adaptor<index_t>{sorted_token_ids_ptr}),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto o_window_ = make_tile_window(
o_scatter_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<BlockShape::Block_N1>{}),
{idx_m0, 0});
return o_window_;
}();
// do the compute
Pipeline{}(a_window,
g_window,
d_window,
o_window,
topk_weight,
smem,
kargs.hidden_size,
kargs.intermediate_size);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp"
namespace ck_tile {
/*
This pipeline deals with a gemm (actually 2 back-to-back gemms) where one operand is very
small (tokens) and the other is very big (weights). We design the pipeline such that all
waves are laid out along the gemm-N dim (gemm-M has only 1 wave)
<----- gemm-N ------>
+----+----+----+----+
| w0 | w1 | w2 | w3 | gemm-M
+----+----+----+----+
*/
template <typename Problem_, typename Policy_ = FusedMoeGemmPipelineGeneralPolicy>
struct FusedMoeGemmPipeline_General
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using BlockShape = typename Problem::BlockShape; // this is FusedMoeGemmShape
using ADataType = typename Problem::ADataType;
using GDataType = typename Problem::GDataType;
using DDataType = typename Problem::DDataType;
using AccDataType = typename Problem::AccDataType;
using ODataType = typename Problem::ODataType;
using AScaleDataType = typename Problem::AScaleDataType;
using GScaleDataType = typename Problem::GScaleDataType;
using DScaleDataType = typename Problem::DScaleDataType;
using YSmoothScaleDataType = typename Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Problem::TopkWeightDataType;
using IndexDataType = typename Problem::IndexDataType;
using YDataType = typename Problem::YDataType;
using Traits = typename Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
static constexpr index_t kAlignmentA = Policy::template GetAlignment_A<Problem>();
static constexpr index_t kAlignmentG = Policy::template GetAlignment_G<Problem>();
static constexpr index_t kAlignmentD = Policy::template GetAlignment_D<Problem>();
static constexpr index_t kAlignmentO = Policy::template GetAlignment_O<Problem>();
static constexpr index_t SLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::SLD_A);
static constexpr index_t GLD_A = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_A);
static constexpr index_t GLD_B = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GLD_B);
static constexpr index_t GST_O = static_cast<index_t>(FusedMoeGemmPipelineSequencerEnum::GST_O);
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
// minimize occupancy
return 2;
}
}();
static constexpr const char* name = "flatmm_gl";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// matrix a or tokens smem
constexpr index_t smem_mat_a =
BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
// shuffle C matrix
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_mat_a, smem_bridge);
// return Policy::template GetSmemSize<Problem>();
}
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetACoord()
{
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
template <typename T>
CK_TILE_HOST_DEVICE static void PrintMem(T& tensor)
{
constexpr auto spans = T::get_distributed_spans();
int counter = 0;
sweep_tile_span(spans[number<0>{}], [&](auto idxn) {
sweep_tile_span(spans[number<1>{}], [&](auto idxk) {
constexpr auto i_j_idx = make_tuple(idxn, idxk);
const auto tile_idx =
get_x_indices_from_distributed_indices(tensor.get_tile_distribution(), i_j_idx);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
printf("in G row is %d , col is %d, counter is %d, value is: %f"
" \n",
row,
col,
counter,
ck_tile::type_convert<float>(tensor(i_j_idx)));
counter = counter + 1;
}
});
});
}
template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
const GWindow& g_window_,
const DWindow& d_window_,
OWindow& o_window_,
TopkWeightDataType /*topk_weight*/,
CK_TILE_LDS_ADDR void* smem,
index_t hidden_size,
index_t intermediate_size)
{
ignore = intermediate_size; // d_window_ and hidden_size are actually used below
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsBlockDesc_A<Problem>());
auto a_lds_win = make_tile_window(
a_lds_view,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
auto a_global_to_dram_window = make_tile_window(
a_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
a_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_A<Problem>());
// load g to register
auto g_global_to_dram_window = make_tile_window(
g_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
g_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_G<Problem>());
// gemm gate
constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// save tokens to lds
auto a_dram_block = load_tile(a_global_to_dram_window);
store_tile(a_lds_win, a_dram_block);
#if 0
PrintMem(a_dram_block);
#endif
auto g_dram_block = load_tile(g_global_to_dram_window);
#if 0
PrintMem(g_dram_block);
#endif
clear_tile(s_acc); // initialize C
constexpr index_t kK0 = BlockShape::Block_K0;
const index_t k0_loops = ck_tile::integer_divide_ceil(hidden_size, kK0);
index_t iCounter0 = k0_loops - 1;
while(iCounter0 > 0)
{
block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block);
block_sync_lds();
move_tile_window(a_global_to_dram_window, {0, kK0});
move_tile_window(g_global_to_dram_window, {0, kK0});
a_dram_block = load_tile(a_global_to_dram_window);
g_dram_block = load_tile(g_global_to_dram_window);
store_tile(a_lds_win, a_dram_block);
iCounter0--;
}
// tail
{
block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block);
}
// activation (gelu)
const auto activation = ck_tile::element_wise::Gelu{};
tile_elementwise_inout(activation, s_acc, s_acc);
#if 0
PrintMem(s_acc);
#endif
// move sacc to LDS
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
auto bridge_slds_win =
make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
{0, 0});
// cast data to YDataType
auto y_pre = cast_tile<YDataType>(s_acc);
#if 0
PrintMem(y_pre);
#endif
// save to lds
store_tile(bridge_slds_win, y_pre);
block_sync_lds();
// gemm down
constexpr auto gemm_1 = Policy::template GetBlockGemm1<Problem>();
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
auto o_acc = OaccBlockTileType{};
// y data
auto bridge_llds_win =
make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
{0, 0},
Policy::template MakeYTileDistribution<Problem>());
auto y = load_tile(bridge_llds_win);
// d data
auto d_global_to_dram_window = make_tile_window(
d_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
d_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_D<Problem>());
auto d = load_tile(d_global_to_dram_window);
constexpr index_t kN1 = BlockShape::Block_N1;
const index_t n1_loops = ck_tile::integer_divide_ceil(hidden_size, kN1);
index_t iCounter1 = n1_loops - 1;
while(iCounter1 > 0)
{
clear_tile(o_acc);
block_sync_lds();
gemm_1(o_acc, y, d);
block_sync_lds();
move_tile_window(d_global_to_dram_window, {kN1, 0});
d = load_tile(d_global_to_dram_window);
// move out window and save data
auto o = cast_tile<ODataType>(o_acc);
store_tile(o_window_, o);
move_tile_window(o_window_, {kN1, 0});
iCounter1--;
}
// tail
{
clear_tile(o_acc);
block_sync_lds();
gemm_1(o_acc, y, d);
auto o = cast_tile<ODataType>(o_acc);
store_tile(o_window_, o);
}
#if 0
PrintMem(o_acc);
#endif
// store_tile(o_window_, a_dram_block);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
namespace ck_tile {
struct FusedMoeGemmPipelineGeneralPolicy
{
CK_TILE_HOST_DEVICE static constexpr index_t GetAsyncCopyDwords()
{
// TODO: always 1 dword
return 1;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
{
// using async
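// e.g. with 1 dword per async copy and 2-byte fp16/bf16 elements this
// yields a per-thread alignment of 2 elements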
constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
{
constexpr index_t copy_bytes = 16;
constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
{
constexpr index_t copy_bytes = 16;
constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
static_assert(copy_bytes % data_bytes == 0);
return copy_bytes / data_bytes;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O()
{
if constexpr(Problem::Traits::OAtomic == 1)
{
// pack fp16/bf16 atomic
static_assert(sizeof(typename Problem::ODataType) == 2);
return 2;
}
else if constexpr(Problem::Traits::OAtomic == 2)
{
// fp32 atomic
return 1;
}
else
{
return 16 / sizeof(typename Problem::ODataType);
}
}
template <typename DataType_>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
{
// TODO: this is for 3d layout
return 16 / sizeof(remove_cvref_t<DataType_>);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_A()
{
return GetSmemKPack<typename Problem::ADataType>();
}
// used for bridge LDS shuffle
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y()
{
// TODO: this should match mfma layout
return 16 / sizeof(typename Problem::YDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{
constexpr auto a_lds_desc = MakeLdsBlockDesc_A<Problem>();
return a_lds_desc.get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_Bridge()
{
constexpr auto bridge_lds_desc = MakeBridgeLdsBlockDesc<Problem>();
return bridge_lds_desc.get_element_space_size();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
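// the A tile in LDS is only live while gemm0 runs, and the bridge buffer is
// only live between the two gemms, so the two stages can share one
// allocation; hence max() rather than the sum.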
constexpr index_t a_lds = GetSmemSize_A<Problem>();
constexpr index_t bridge_lds = GetSmemSize_Bridge<Problem>();
return max(a_lds, bridge_lds);
}
template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_SimpleMxK()
{
constexpr index_t K_vec = Alignment;
constexpr index_t K_rem = KPerBlock / K_vec;
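// a worked example (assuming a 64-lane wave): MPerBlock=32, KPerBlock=64,
// NumWarps=4, Alignment=8 gives K_vec=8, K_rem=8; since 64 >= K_rem we take
// the else-branch with K_lan=8, M_lan=8, M_wav=4 and M_rep=32/(8*4)=1,
// i.e. each lane loads one 8-element K vector and 32 rows are covered.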
if constexpr(get_warp_size() < K_rem)
{
static_assert(K_rem % get_warp_size() == 0);
constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
constexpr index_t K_wav = K_rem / get_warp_size();
static_assert(K_wav <= NumWarps, "thread repeat along K is not supported yet");
constexpr index_t M_wav = NumWarps / K_wav;
static_assert(MPerBlock % M_wav == 0, "this tile size is too small, please check");
constexpr index_t M_rep = MPerBlock / M_wav;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav>, sequence<K_wav, K_lan, K_vec>>,
tuple<sequence<1, 2>, sequence<2>>,
tuple<sequence<1, 0>, sequence<1>>,
sequence<1, 2>,
sequence<0, 2>>{});
}
else
{
constexpr index_t K_lan = K_rem;
constexpr index_t M_lan = get_warp_size() / K_lan;
constexpr index_t M_wav = NumWarps;
static_assert(MPerBlock % (M_lan * M_wav) == 0,
"this tile size is too small, please check");
constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<M_rep, M_wav, M_lan>, sequence<K_lan, K_vec>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_A()
{
constexpr index_t Block_M_ = Problem::BlockShape::Block_M0;
constexpr index_t Block_K_ = Problem::BlockShape::Block_K0;
constexpr index_t NumWarps_ = Problem::BlockShape::NumWarps;
constexpr index_t Alignment_ = GetAlignment_A<Problem>();
return MakeGlobalTileDistribution_SimpleMxK<Block_M_, Block_K_, NumWarps_, Alignment_>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_G()
{
using WG = decltype(GetWarpGemm0<Problem>());
using S_ = typename Problem::BlockShape;
static_assert(S_::WarpPerBlock_N0 == 4);
constexpr auto g_outer_dstr_enc = tile_distribution_encoding<
sequence<S_::WarpPerBlock_M0>,
tuple<sequence<S_::Repeat_N0, S_::WarpPerBlock_N0>, sequence<S_::Repeat_K0>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto g_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
g_outer_dstr_enc, typename WG::BWarpDstrEncoding{});
return make_static_tile_distribution(g_block_dstr_encode);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm0()
{
using S_ = typename Problem::BlockShape;
using GemmProblem = BlockGemmProblem<typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
S_::BlockSize,
TileGemmShape<typename S_::BlockTile_0,
typename S_::WarpPerBlock_0,
typename S_::WarpTile_0>>;
constexpr auto warp_gemm = GetWarpGemm0<Problem>();
using BlockGemmPolicy = BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
typename S_::WarpPerBlock_0,
decltype(warp_gemm)>;
return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm1()
{
using S_ = typename Problem::BlockShape;
using GemmProblem = BlockGemmProblem<typename Problem::YDataType,
typename Problem::DDataType,
typename Problem::AccDataType,
S_::BlockSize,
TileGemmShape<typename S_::BlockTile_1,
typename S_::WarpPerBlock_1,
typename S_::WarpTile_1>>;
constexpr auto warp_gemm = GetWarpGemm1<Problem>();
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::YDataType,
typename Problem::DDataType,
typename Problem::AccDataType,
typename S_::WarpPerBlock_1,
decltype(warp_gemm)>;
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
}
// this is used as the A matrix for the 2nd gemm
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeYTileDistribution()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
constexpr auto y_outer_dstr_enc = tile_distribution_encoding<
sequence<1>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>, sequence<S_::WarpPerBlock_K1, S_::Repeat_K1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{};
constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
return y_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_D()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
constexpr auto d_outer_dstr_enc = tile_distribution_encoding<
sequence<1>,
tuple<sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>, sequence<S_::WarpPerBlock_K1, S_::Repeat_K1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{};
constexpr auto d_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
d_outer_dstr_enc, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto d_block_dstr = make_static_tile_distribution(d_block_dstr_encode);
return d_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
{
using S_ = remove_cvref_t<typename Problem::BlockShape>;
using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// using CDataType = typename WarpGemm::CDataType;
constexpr auto c_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>,
tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDesc_A()
{
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
constexpr index_t kK1 = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t kK0 = Block_K / kK1;
static_assert(Block_K % kK1 == 0);
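// the K0-major stride of (Block_M + 1) * kK1 skews each K0 slice by one
// extra kK1 vector; this padding is presumably there to avoid LDS bank
// conflicts when the tile is read back along M.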
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kK0>{}, number<Block_M>{}, number<kK1>{}),
make_tuple(number<(Block_M + 1) * kK1>{}, number<kK1>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<Block_M>{}),
make_merge_transform(make_tuple(number<kK0>{}, number<kK1>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsBlockDesc()
{
constexpr index_t Block_M = Problem::BlockShape::Block_M0;
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>();
constexpr index_t KPad = 0;
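// plain row-major [Block_M, Block_N] view with no padding for now (KPad=0);
// the KVector argument is the descriptor's vector/alignment hint so that y
// can be reloaded in GetSmemKPack_Y-sized chunks.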
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + KPad>{}, number<1>{}),
number<KVector>{},
number<1>{});
return desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
{
using S_ = typename Problem::BlockShape;
// A is in vgpr, B is in agpr; but since we use the transposed-C distribution,
// the two operands also need to be swapped here
// TODO: this is ugly
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
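// dispatch on (ADataType, GDataType, warp tile); note that if no branch
// matches, the deduced return type is void and the caller fails to compile,
// so unsupported combinations surface as a (cryptic) build error.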
if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<wg_ctrl>,
1>>{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
1>>{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{};
}
else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
2>>{};
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
{
using S_ = typename Problem::BlockShape;
constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
// TODO: ugly
if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<wg_ctrl>,
1>>{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 8)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
1>>{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<wg_ctrl>,
2>>{};
}
else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
{
return WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImpl_i32_32x32x16_i8<wg_ctrl>,
2>>{};
}
}
};
} // namespace ck_tile
@@ -6,6 +6,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV2
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// M->N Warp
// constexpr auto a_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<NWarp>,
// tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
// tuple<sequence<1, 0>>,
// tuple<sequence<1, 0>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// constexpr auto b_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<MWarp>,
// tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
// tuple<sequence<0, 1>>,
// tuple<sequence<0, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
// constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
// constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// check ABC-block-distribution
// static_assert(
// std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
// remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
// .get_static_tile_distribution_encoding())>>,
// "A distribution is wrong!");
// static_assert(
// std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
// remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
// .get_static_tile_distribution_encoding())>>,
// "B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
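// loop order is k -> m -> n: the A warp slice for a given (m, k) is read
// once and reused across all n iterations, while each C tile is
// read-modified-written per (m, n) so accumulation over k stays in registers.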
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
template <typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
return c_block_tensor;
}
};
} // namespace ck_tile