Commit 019d4b7c authored by illsilin's avatar illsilin
Browse files

merge from public repo

parents 07307ea1 5063a39f
...@@ -3,18 +3,42 @@ ...@@ -3,18 +3,42 @@
#include "moe_sorting_api.hpp" #include "moe_sorting_api.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \ #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \ using ms_problem = \
auto kargs = kernel::MakeKargs(a); \ ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
const dim3 grids = kernel::GridSize(a); \ using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
const dim3 blocks = kernel::BlockSize(a); \ auto kargs = kernel::MakeKargs(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \ const dim3 grids = kernel::GridSize(a); \
float ave_time = ck_tile::launch_kernel( \ const dim3 blocks = kernel::BlockSize(a); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
...@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi ...@@ -49,21 +73,12 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
case(6): { case(6): {
MOE_SORTING_DISPATCH(6); MOE_SORTING_DISPATCH(6);
} }
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): { case(8): {
MOE_SORTING_DISPATCH(8); MOE_SORTING_DISPATCH(8);
} }
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): { case(10): {
MOE_SORTING_DISPATCH(10); MOE_SORTING_DISPATCH(10);
} }
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: { default: {
MOE_SORTING_DISPATCH(4); MOE_SORTING_DISPATCH(4);
} }
......
...@@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19 ...@@ -16,4 +16,5 @@ $EXE -t=127 -e=99 -k=19
$EXE -t=71 -e=11 -k=11 $EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1 $EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1 $EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13 $EXE -t=333 -e=99 -k=13
\ No newline at end of file $EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp> #include <ck_tile/core.hpp>
#include "moe_smoothquant.hpp" #include "moe_smoothquant.hpp"
...@@ -35,7 +35,7 @@ float moe_smoothquant_(const S& s, A a) ...@@ -35,7 +35,7 @@ float moe_smoothquant_(const S& s, A a)
using PipelineProblem = ck_tile::SmoothquantPipelineProblem< using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
typename MoeSmoothquantTypeConfig<DataType>::XDataType, typename MoeSmoothquantTypeConfig<DataType>::XDataType,
typename MoeSmoothquantTypeConfig<DataType>::XScaleDataType, typename MoeSmoothquantTypeConfig<DataType>::SmoothScaleDataType,
typename MoeSmoothquantTypeConfig<DataType>::ComputeDataType, typename MoeSmoothquantTypeConfig<DataType>::ComputeDataType,
typename MoeSmoothquantTypeConfig<DataType>::YScaleDataType, typename MoeSmoothquantTypeConfig<DataType>::YScaleDataType,
typename MoeSmoothquantTypeConfig<DataType>::QYDataType, typename MoeSmoothquantTypeConfig<DataType>::QYDataType,
......
...@@ -91,15 +91,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -91,15 +91,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
using TypeConfig = MoeSmoothquantTypeConfig<DataType>; using TypeConfig = MoeSmoothquantTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using XScaleDataType = typename TypeConfig::XScaleDataType; using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType; using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType; using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = typename TypeConfig::ComputeDataType; using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({tokens, hidden_size}, {stride, 1}); ck_tile::HostTensor<XDataType> x_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({experts * hidden_size}); ck_tile::HostTensor<SmoothScaleDataType> smscale_host({experts * hidden_size});
ck_tile::HostTensor<ck_tile::index_t> topk_ids_host({tokens, topk}); ck_tile::HostTensor<ck_tile::index_t> topk_ids_host({tokens, topk});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({topk * tokens}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_ref({topk * tokens}, {1});
...@@ -110,16 +110,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -110,16 +110,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
topid_unique_gen<ck_tile::index_t>(topk_ids_host.mData, tokens, topk, experts, 11937); topid_unique_gen<ck_tile::index_t>(topk_ids_host.mData, tokens, topk, experts, 11937);
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host); ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem topk_ids_buf(topk_ids_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem topk_ids_buf(topk_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
xscale_buf.ToDevice(xscale_host.data()); smscale_buf.ToDevice(smscale_host.data());
topk_ids_buf.ToDevice(topk_ids_host.data()); topk_ids_buf.ToDevice(topk_ids_host.data());
std::cout << "[" << data_type << "]" std::cout << "[" << data_type << "]"
...@@ -129,7 +129,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -129,7 +129,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
moe_smoothquant_traits traits{data_type}; moe_smoothquant_traits traits{data_type};
moe_smoothquant_args args{x_buf.GetDeviceBuffer(), moe_smoothquant_args args{x_buf.GetDeviceBuffer(),
xscale_buf.GetDeviceBuffer(), smscale_buf.GetDeviceBuffer(),
topk_ids_buf.GetDeviceBuffer(), topk_ids_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(),
...@@ -143,9 +143,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -143,9 +143,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = moe_smoothquant( float ave_time = moe_smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = std::size_t num_byte = sizeof(XDataType) * tokens * hidden_size +
sizeof(XDataType) * tokens * hidden_size + sizeof(XScaleDataType) * topk * hidden_size + sizeof(SmoothScaleDataType) * topk * hidden_size +
sizeof(YScaleDataType) * topk * tokens + sizeof(QYDataType) * topk * tokens * hidden_size; sizeof(YScaleDataType) * topk * tokens +
sizeof(QYDataType) * topk * tokens * hidden_size;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
...@@ -165,11 +166,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -165,11 +166,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
for(int i_h = 0; i_h < hidden_size; ++i_h) for(int i_h = 0; i_h < hidden_size; ++i_h)
{ {
auto v_xscale = ck_tile::type_convert<ComputeDataType>( auto v_smscale = ck_tile::type_convert<ComputeDataType>(
xscale_host(i_expert * hidden_size + i_h)); smscale_host(i_expert * hidden_size + i_h));
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(i_token, i_h)); auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(i_token, i_h));
// y_host(i_token * topk + i_topk, i_h) = v_x * v_xscale; // y_host(i_token * topk + i_topk, i_h) = v_x * v_smscale;
y_host(i_topk * tokens + i_token, i_h) = v_x * v_xscale; y_host(i_topk * tokens + i_token, i_h) = v_x * v_smscale;
} }
} }
}; };
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,21 +14,21 @@ struct MoeSmoothquantTypeConfig; ...@@ -14,21 +14,21 @@ struct MoeSmoothquantTypeConfig;
template <> template <>
struct MoeSmoothquantTypeConfig<ck_tile::half_t> struct MoeSmoothquantTypeConfig<ck_tile::half_t>
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using XScaleDataType = float; using SmoothScaleDataType = float;
using YScaleDataType = float; using YScaleDataType = float;
using QYDataType = ck_tile::int8_t; using QYDataType = ck_tile::int8_t;
using ComputeDataType = float; using ComputeDataType = float;
}; };
template <> template <>
struct MoeSmoothquantTypeConfig<ck_tile::bf16_t> struct MoeSmoothquantTypeConfig<ck_tile::bf16_t>
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using XScaleDataType = float; using SmoothScaleDataType = float;
using YScaleDataType = float; using YScaleDataType = float;
using QYDataType = ck_tile::int8_t; using QYDataType = ck_tile::int8_t;
using ComputeDataType = float; using ComputeDataType = float;
}; };
// runtime args // runtime args
......
...@@ -8,6 +8,9 @@ The benifit of this fused-moe: ...@@ -8,6 +8,9 @@ The benifit of this fused-moe:
* much less kernel instance, easy to maintain * much less kernel instance, easy to maintain
# Implementation and feature support # Implementation and feature support
## NOTES:
currently gate+up in fp16 case will very easily cause accumulator overflow the fp16 max(65504), hence result in INF. Please use BF16 for gate+up case, API side will have no check for this.
## moe-sorting ## moe-sorting
this is a common pre-process step before the actual moe-gemm. The purpose is to transform the moe loop over from token-by-token to expert-by-expert, make sure very workgroup is working for a single expert (B matrix). Besides, we extend this op to do the zeroing of the output buffer(to be used for reduce buffer with atomic) this is a common pre-process step before the actual moe-gemm. The purpose is to transform the moe loop over from token-by-token to expert-by-expert, make sure very workgroup is working for a single expert (B matrix). Besides, we extend this op to do the zeroing of the output buffer(to be used for reduce buffer with atomic)
......
...@@ -26,7 +26,7 @@ struct fused_moe_args ...@@ -26,7 +26,7 @@ struct fused_moe_args
ck_tile::index_t block_m; // block_m, used to devide the input ck_tile::index_t block_m; // block_m, used to devide the input
ck_tile::index_t hidden_size; // k ck_tile::index_t hidden_size; // k
ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2 ck_tile::index_t intermediate_size; // n / TP, for Gate. and Up, Down is also this value
ck_tile::index_t num_tokens; // input number of tokens for current iteration ck_tile::index_t num_tokens; // input number of tokens for current iteration
ck_tile::index_t num_experts; // number of groups ck_tile::index_t num_experts; // number of groups
ck_tile::index_t topk; // need this? ck_tile::index_t topk; // need this?
...@@ -45,7 +45,8 @@ struct fused_moe_traits ...@@ -45,7 +45,8 @@ struct fused_moe_traits
std::string prec_sq; // smooth quant scale std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type std::string prec_kw; // topk-weight data type
int block_m; int block_m;
int gate_only; int activation; // 0:gelu, 1:silu
int gate_only; // 0:g1u0, 1:g1u1
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
}; };
......
...@@ -77,7 +77,8 @@ struct fused_moegemm_traits ...@@ -77,7 +77,8 @@ struct fused_moegemm_traits
std::string prec_sq; // smooth quant scale std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type std::string prec_kw; // topk-weight data type
int block_m; int block_m;
int gate_only; int activation; // 0:gelu, 1:silu
int gate_only; // 0:g1u0, 1:g1u1
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
}; };
......
...@@ -41,6 +41,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf ...@@ -41,6 +41,7 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
t.prec_sq, t.prec_sq,
t.prec_kw, t.prec_kw,
t.block_m, t.block_m,
t.activation,
t.gate_only, t.gate_only,
t.fused_quant}; t.fused_quant};
auto a1 = fused_moegemm_args{ auto a1 = fused_moegemm_args{
......
...@@ -17,15 +17,67 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile: ...@@ -17,15 +17,67 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
// clang-format off // clang-format off
float r = -1; float r = -1;
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" && if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
{ {
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 0)
{
constexpr ck_tile::index_t act_ = 0;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a);
}
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1 && t.activation == 1)
{
constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 1;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a); r = fused_moegemm_<t_>(s, a);
} }
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" && else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1) t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 0 && t.activation == 1)
{ {
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>; constexpr ck_tile::index_t act_ = 1;
constexpr ck_tile::index_t go_ = 0;
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, act_, go_, 0>;
r = fused_moegemm_<t_>(s, a); r = fused_moegemm_<t_>(s, a);
} }
// clang-format on // clang-format on
......
...@@ -21,21 +21,31 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) ...@@ -21,21 +21,31 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
typename Ts_::BlockTile_1, typename Ts_::BlockTile_1,
typename Ts_::WarpPerBlock_0, typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0>; typename Ts_::WarpTile_0>;
using f_problem =
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType, constexpr auto get_activation_ = []() {
typename Ts_::GDataType, if constexpr(Ts_::Activation == 0)
typename Ts_::DDataType, {
typename Ts_::AccDataType, return ck_tile::element_wise::FastGeluAsm{};
typename Ts_::ODataType, }
typename Ts_::AScaleDataType, else
typename Ts_::GScaleDataType, return ck_tile::element_wise::Silu{};
typename Ts_::DScaleDataType, };
typename Ts_::YSmoothScaleDataType, using f_act_ = ck_tile::remove_cvref_t<decltype(get_activation_())>;
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType, using f_problem = ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
ck_tile::element_wise::FastGeluAsm, // TODO: hardcoded typename Ts_::GDataType,
f_shape, typename Ts_::DDataType,
f_traits>; typename Ts_::AccDataType,
typename Ts_::ODataType,
typename Ts_::AScaleDataType,
typename Ts_::GScaleDataType,
typename Ts_::DScaleDataType,
typename Ts_::YSmoothScaleDataType,
typename Ts_::TopkWeightDataType,
typename Ts_::IndexDataType,
f_act_, // TODO: hardcoded
f_shape,
f_traits>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>; // using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>; using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
......
...@@ -15,7 +15,8 @@ template <typename I, ...@@ -15,7 +15,8 @@ template <typename I,
typename KW, typename KW,
typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down> typename BlockTIle_, // seq<b_token, b_interm, b_hidden, b_down>
typename WarpPerBlock_, typename WarpPerBlock_,
typename WarpTile_, // seq<*,*,*>, used to select mfma typename WarpTile_, // seq<*,*,*>, used to select mfma
ck_tile::index_t Activation_ = 0, // 0: Gelu 1: Silu
ck_tile::index_t GateOnly_ = 0, ck_tile::index_t GateOnly_ = 0,
ck_tile::index_t FusedQuant_ = 0> ck_tile::index_t FusedQuant_ = 0>
struct fmoe_ // traits, ugly name, only used for internal struct fmoe_ // traits, ugly name, only used for internal
...@@ -44,10 +45,11 @@ struct fmoe_ // traits, ugly name, only used for internal ...@@ -44,10 +45,11 @@ struct fmoe_ // traits, ugly name, only used for internal
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>; using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>; using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>; using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>; using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>; using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
static constexpr ck_tile::index_t Activation = Activation_; // 0: Gelu 1: Silu
static constexpr ck_tile::index_t GateOnly = GateOnly_; static constexpr ck_tile::index_t GateOnly = GateOnly_;
static constexpr ck_tile::index_t FusedQuant = FusedQuant_; static constexpr ck_tile::index_t FusedQuant = FusedQuant_;
}; };
...@@ -8,7 +8,18 @@ ...@@ -8,7 +8,18 @@
// clang-format off // clang-format off
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -8,7 +8,19 @@ ...@@ -8,7 +8,19 @@
// clang-format off // clang-format off
template float fused_moegemm_< template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0> fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 0, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a);
template float fused_moegemm_<
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 1, 0>
>(const ck_tile::stream_config& s, fused_moegemm_args a); >(const ck_tile::stream_config& s, fused_moegemm_args a);
// clang-format on // clang-format on
...@@ -3,18 +3,42 @@ ...@@ -3,18 +3,42 @@
#include "fused_moesorting.hpp" #include "fused_moesorting.hpp"
#define MOE_SORTING_DISPATCH(unroll_num_) \ #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t unroll_num = unroll_num_; \
using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \ using ms_problem = \
auto kargs = kernel::MakeKargs(a); \ ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num, expert_tile>; \
const dim3 grids = kernel::GridSize(a); \ using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
const dim3 blocks = kernel::BlockSize(a); \ auto kargs = kernel::MakeKargs(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \ const dim3 grids = kernel::GridSize(a); \
float ave_time = ck_tile::launch_kernel( \ const dim3 blocks = kernel::BlockSize(a); \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
#define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 8) \
} \
else if(a.num_experts <= 16) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 16) \
} \
else if(a.num_experts <= 32) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 32) \
} \
else if(a.num_experts <= 64) \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 64) \
} \
else \
{ \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
}
float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s) float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
...@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til ...@@ -49,21 +73,12 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
case(6): { case(6): {
MOE_SORTING_DISPATCH(6); MOE_SORTING_DISPATCH(6);
} }
case(7): {
MOE_SORTING_DISPATCH(7);
}
case(8): { case(8): {
MOE_SORTING_DISPATCH(8); MOE_SORTING_DISPATCH(8);
} }
case(9): {
MOE_SORTING_DISPATCH(9);
}
case(10): { case(10): {
MOE_SORTING_DISPATCH(10); MOE_SORTING_DISPATCH(10);
} }
case(11): {
MOE_SORTING_DISPATCH(11);
}
default: { default: {
MOE_SORTING_DISPATCH(4); MOE_SORTING_DISPATCH(4);
} }
......
...@@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[]) ...@@ -108,12 +108,14 @@ auto create_args(int argc, char* argv[])
.insert( .insert(
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm") .insert("api", "0", "benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm")
.insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
.insert("balance", .insert("balance",
"0", "0",
"if set to 1, will try balance the expert in topk-ids(convenient for testing)") "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
.insert("init", .insert("init",
"2", "1",
"init method. 0:random stepped float(fast). 1: random uniform, 2:rand normalized" "init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand "
"normalized[0, 1]"
"normalized(slow)") "normalized(slow)")
.insert("seed", "11939", "seed used to do random") .insert("seed", "11939", "seed used to do random")
.insert("warmup", "5", "cold iter") .insert("warmup", "5", "cold iter")
...@@ -135,6 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -135,6 +137,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::index_t intermediate_size = arg_parser.get_int("i"); ck_tile::index_t intermediate_size = arg_parser.get_int("i");
ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t stride = arg_parser.get_int("stride");
ck_tile::index_t block_m = arg_parser.get_int("bm"); ck_tile::index_t block_m = arg_parser.get_int("bm");
ck_tile::index_t activation = arg_parser.get_int("act");
if(stride < 0) if(stride < 0)
stride = hidden_size; stride = hidden_size;
std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_i = arg_parser.get_str("prec_i");
...@@ -194,11 +197,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -194,11 +197,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
return std::string(", st:") + std::to_string(stride); return std::string(", st:") + std::to_string(stride);
}(); }();
std::cout << "[" << api_str << "|" << prec_str << "]" std::cout
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str << "[" << api_str << "|" << prec_str << "]"
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp << " t:" << tokens << ", e:" << experts << ", k:" << topk << stride_str
<< ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1 << ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush; << ", act:"
<< activation
// << ", shrd_interm:" << shared_intermediate_size_0 << "|" << shared_intermediate_size_1
<< (gate_only ? ", g1u0" : ", g1u1") << ", q:" << fused_quant << std::flush;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>; using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType; using ADataType = typename TypeConfig::ADataType;
...@@ -370,6 +376,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -370,6 +376,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sq, prec_sq,
prec_kw, prec_kw,
block_m, block_m,
activation,
gate_only, gate_only,
fused_quant}; fused_quant};
...@@ -389,7 +396,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -389,7 +396,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_buf.GetDeviceBuffer(), num_sorted_tiles_buf.GetDeviceBuffer(),
block_m, block_m,
hidden_size, hidden_size,
shared_intermediate_size_0, intermediate_size / tp,
tokens, tokens,
experts, experts,
topk, topk,
...@@ -408,6 +415,28 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -408,6 +415,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< cal_tbps(ave_time) << " TB/s" << std::flush; << cal_tbps(ave_time) << " TB/s" << std::flush;
bool pass = true; bool pass = true;
#define CPU_FUSED_MOE(act_type_) \
ck_tile::reference_fused_moe<AccDataType, act_type_>(a_host, \
g_host, \
d_host, \
sa_host, \
sg_host, \
sd_host, \
sy_host, \
o_host, \
sorted_token_ids_host, \
sorted_weight_host, \
sorted_expert_ids_host, \
num_sorted_tiles_host, \
topk_ids_host, \
block_m, \
tokens, \
experts, \
hidden_size, \
intermediate_size / tp, \
topk, \
gate_only)
if(do_validation) if(do_validation)
{ {
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>( ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
...@@ -419,28 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -419,28 +448,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
num_sorted_tiles_host.mData[0], num_sorted_tiles_host.mData[0],
experts, experts,
block_m); block_m);
if(activation == 0)
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>( {
a_host, CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
g_host, }
d_host, else
sa_host, {
sg_host, CPU_FUSED_MOE(ck_tile::element_wise::Silu);
sd_host, }
sy_host,
o_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host,
topk_ids_host,
block_m,
tokens,
experts,
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
auto o_dev = o_buf.ToHost<ODataType>(); auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float"); // o_dev.savetxt("gpu-out.txt", "float");
...@@ -491,6 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -491,6 +506,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
prec_sq, prec_sq,
prec_kw, prec_kw,
block_m, block_m,
activation,
gate_only, gate_only,
fused_quant}; fused_quant};
...@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -507,7 +523,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
sorted_expert_ids_buf.GetDeviceBuffer(), sorted_expert_ids_buf.GetDeviceBuffer(),
num_sorted_tiles_buf.GetDeviceBuffer(), num_sorted_tiles_buf.GetDeviceBuffer(),
hidden_size, hidden_size,
shared_intermediate_size_0, intermediate_size / tp,
tokens, tokens,
experts, experts,
topk, topk,
...@@ -529,27 +545,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -529,27 +545,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation) if(do_validation)
{ {
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>( if(activation == 0)
a_host, {
g_host, CPU_FUSED_MOE(ck_tile::element_wise::Gelu);
d_host, }
sa_host, else
sg_host, {
sd_host, CPU_FUSED_MOE(ck_tile::element_wise::Silu);
sy_host, }
o_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host,
topk_ids_host,
block_m,
tokens,
experts,
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
auto o_dev = o_buf.ToHost<ODataType>(); auto o_dev = o_buf.ToHost<ODataType>();
// o_dev.savetxt("gpu-out.txt", "float"); // o_dev.savetxt("gpu-out.txt", "float");
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "batched_gemm.hpp" #include "batched_gemm.hpp"
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& s) float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{ {
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadM = false; constexpr bool kPadM = false;
...@@ -70,20 +70,25 @@ float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config& ...@@ -70,20 +70,25 @@ float batched_gemm(const batched_gemm_kargs& args, const ck_tile::stream_config&
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>; using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>; using Kernel = ck_tile::BatchedGemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(args); auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
constexpr dim3 blocks = Kernel::BlockSize(); constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Launching kernel with args:" std::cout << "Launching kernel with args:"
......
...@@ -29,10 +29,6 @@ using BDataType = Types::BDataType; ...@@ -29,10 +29,6 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType; using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType; using CDataType = Types::CDataType;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
{
};
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
...@@ -53,11 +49,12 @@ auto create_args(int argc, char* argv[]) ...@@ -53,11 +49,12 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer"); .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
} }
// host API // host API
float batched_gemm(batched_gemm_kargs args, const ck_tile::stream_config& s); float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s);
...@@ -17,13 +17,15 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -17,13 +17,15 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::index_t batch_stride_B, ck_tile::index_t batch_stride_B,
ck_tile::index_t batch_stride_C, ck_tile::index_t batch_stride_C,
ck_tile::index_t batch_count, ck_tile::index_t batch_count,
ck_tile::index_t kbatch,
int n_warmup, int n_warmup,
int n_repeat) int n_repeat)
{ {
batched_gemm_kargs args; ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M; args.M = M;
args.N = N; args.N = N;
args.K = K; args.K = K;
...@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -79,6 +81,7 @@ int run_batched_gemm_example_with_layouts(int argc,
ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b"); ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c"); ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
ck_tile::index_t batch_count = arg_parser.get_int("batch_count"); ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup"); int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat"); int n_repeat = arg_parser.get_int("repeat");
...@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -159,6 +162,7 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_B, batch_stride_B,
batch_stride_C, batch_stride_C,
batch_count, batch_count,
kbatch,
n_warmup, n_warmup,
n_repeat); n_repeat);
...@@ -188,15 +192,33 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -188,15 +192,33 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref.SetZero(); c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ck_tile::reference_batched_gemm_gpu<ADataType, ck_tile::reference_batched_gemm_gpu<ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
CDataType, CDataType,
ALayout, ALayout,
BLayout, BLayout,
CLayout>(a_m_k_dev_buf, CLayout>(d_A,
b_k_n_dev_buf, d_B,
c_m_n_gpu_buf_ref, d_C,
M, M,
N, N,
K, K,
...@@ -208,6 +230,15 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -208,6 +230,15 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_C, batch_stride_C,
batch_count); batch_count);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
......
...@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; ...@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("a_layout", "R", "A tensor data layout - Row by default") arg_parser.insert("Ms", "", "M dimensions - empty by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default") .insert("Ns", "", "N dimensions - empty by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default") .insert("Ks", "", "K dimensions - empty by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU") .insert("stride_As", "", "Tensor A strides - it is empty by default.")
.insert("warmup", "10", "number of iterations before benchmark the kernel") .insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("group_count", "16", "group count"); .insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "16", "group count.");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment