Commit 7572a691 authored by coderfeli's avatar coderfeli
Browse files

merge develop

parents 7796fc73 6b6fcd37
#!/bin/sh #!/bin/sh
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"; do
for pr_i in "fp16" "bf16" ; do for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13 for fadd in "0" "1"; do
$EXE -prec=$pr_i -m=17 -n=16 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
$EXE -prec=$pr_i -m=1 -n=100 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
$EXE -prec=$pr_i -m=4 -n=128 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
$EXE -prec=$pr_i -m=80 -n=127 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
$EXE -prec=$pr_i -m=22 -n=255 -stride=256 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
$EXE -prec=$pr_i -m=7 -n=599 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
$EXE -prec=$pr_i -m=19 -n=512 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
$EXE -prec=$pr_i -m=11 -n=510 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
$EXE -prec=$pr_i -m=171 -n=676 -stride=818 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
$EXE -prec=$pr_i -m=91 -n=636 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
$EXE -prec=$pr_i -m=12 -n=768 -stride=800 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
$EXE -prec=$pr_i -m=100 -n=766 -stride=812 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
$EXE -prec=$pr_i -m=31 -n=1024 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
$EXE -prec=$pr_i -m=8 -n=1501 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
$EXE -prec=$pr_i -m=3 -n=1826 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
$EXE -prec=$pr_i -m=5 -n=2040 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
$EXE -prec=$pr_i -m=7 -n=2734 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
$EXE -prec=$pr_i -m=1 -n=3182 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec=$pr_i -m=9 -n=4096 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec=$pr_i -m=3 -n=8192 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec=$pr_i -m=1 -n=10547 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
$EXE -prec=$pr_i -m=3 -n=17134 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done done
...@@ -63,17 +63,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -63,17 +63,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
assert(stride >= n); assert(x_stride >= n);
using XDataType = DataType; using XDataType = DataType;
using XScaleDataType = float; using SmoothScaleDataType = float;
using YScaleDataType = float; using YScaleDataType = float;
using QYDataType = ck_tile::int8_t; using QYDataType = ck_tile::int8_t;
using ComputeDataType = float; using ComputeDataType = float;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n}); ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
...@@ -82,15 +82,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -82,15 +82,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host); ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
xscale_buf.ToDevice(xscale_host.data()); smscale_buf.ToDevice(smscale_host.data());
constexpr bool kTwoPass = true; constexpr bool kTwoPass = true;
...@@ -101,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -101,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>; using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::SmoothquantPipelineProblem<XDataType, using Problem = ck_tile::SmoothquantPipelineProblem<XDataType,
XScaleDataType, SmoothScaleDataType,
ComputeDataType, ComputeDataType,
YScaleDataType, YScaleDataType,
QYDataType, QYDataType,
...@@ -115,7 +115,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -115,7 +115,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Kernel = ck_tile::Smoothquant<Pipeline>; using Kernel = ck_tile::Smoothquant<Pipeline>;
ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(), ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(),
xscale_buf.GetDeviceBuffer(), smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(),
m, m,
...@@ -142,16 +142,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -142,16 +142,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// smooth outlier // smooth outlier
{ {
auto f = [&](auto n_) { auto f = [&](auto n_) {
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_)); auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_) for(int m_ = 0; m_ < m; ++m_)
{ {
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_)); auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_xscale; y_host(m_, n_) = v_x * v_smscale;
} }
}; };
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp> #include <ck_tile/core.hpp>
#include "smoothquant.hpp" #include "smoothquant.hpp"
...@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a) ...@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a)
using PipelineProblem = ck_tile::SmoothquantPipelineProblem< using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
typename SmoothquantTypeConfig<DataType>::XDataType, typename SmoothquantTypeConfig<DataType>::XDataType,
typename SmoothquantTypeConfig<DataType>::XScaleDataType, typename SmoothquantTypeConfig<DataType>::SmoothScaleDataType,
typename SmoothquantTypeConfig<DataType>::ComputeDataType, typename SmoothquantTypeConfig<DataType>::ComputeDataType,
typename SmoothquantTypeConfig<DataType>::YScaleDataType, typename SmoothquantTypeConfig<DataType>::YScaleDataType,
typename SmoothquantTypeConfig<DataType>::QYDataType, typename SmoothquantTypeConfig<DataType>::QYDataType,
......
...@@ -66,15 +66,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -66,15 +66,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
using TypeConfig = SmoothquantTypeConfig<DataType>; using TypeConfig = SmoothquantTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using XScaleDataType = typename TypeConfig::XScaleDataType; using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType; using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType; using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = typename TypeConfig::ComputeDataType; using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n}); ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
...@@ -83,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -83,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host); ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
xscale_buf.ToDevice(xscale_host.data()); smscale_buf.ToDevice(smscale_host.data());
std::cout << "[" << data_type << "]" std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
...@@ -100,7 +100,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -100,7 +100,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
smoothquant_traits traits{data_type}; smoothquant_traits traits{data_type};
smoothquant_args args{x_buf.GetDeviceBuffer(), smoothquant_args args{x_buf.GetDeviceBuffer(),
xscale_buf.GetDeviceBuffer(), smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(),
m, m,
...@@ -111,7 +111,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -111,7 +111,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = smoothquant( float ave_time = smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XScaleDataType) * n + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n +
sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n; sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
...@@ -126,16 +126,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -126,16 +126,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// smooth outlier // smooth outlier
{ {
auto f = [&](auto n_) { auto f = [&](auto n_) {
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_)); auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_) for(int m_ = 0; m_ < m; ++m_)
{ {
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_)); auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_xscale; y_host(m_, n_) = v_x * v_smscale;
} }
}; };
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -14,21 +14,21 @@ struct SmoothquantTypeConfig; ...@@ -14,21 +14,21 @@ struct SmoothquantTypeConfig;
template <> template <>
struct SmoothquantTypeConfig<ck_tile::half_t> struct SmoothquantTypeConfig<ck_tile::half_t>
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using XScaleDataType = float; using SmoothScaleDataType = float;
using YScaleDataType = float; using YScaleDataType = float;
using QYDataType = ck_tile::int8_t; using QYDataType = ck_tile::int8_t;
using ComputeDataType = float; using ComputeDataType = float;
}; };
template <> template <>
struct SmoothquantTypeConfig<ck_tile::bf16_t> struct SmoothquantTypeConfig<ck_tile::bf16_t>
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using XScaleDataType = float; using SmoothScaleDataType = float;
using YScaleDataType = float; using YScaleDataType = float;
using QYDataType = ck_tile::int8_t; using QYDataType = ck_tile::int8_t;
using ComputeDataType = float; using ComputeDataType = float;
}; };
// runtime args // runtime args
......
...@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[]) ...@@ -26,6 +26,10 @@ auto create_args(int argc, char* argv[])
.insert("k", "4", "topk") .insert("k", "4", "topk")
.insert("unit", "32", "unit_size") .insert("unit", "32", "unit_size")
.insert("moe_buf_size", "0", "moe_buf_size") .insert("moe_buf_size", "0", "moe_buf_size")
.insert("local_eid",
"-1",
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
"please make sure eid is in ascending order!")
.insert("seed", "-1", "seed to be used, -1 means random every time") .insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name") .insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("warmup", "5", "number of iterations before benchmark the kernel")
...@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -74,6 +78,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int kname = args.get_int("kname"); int kname = args.get_int("kname");
int warmup = args.get_int("warmup"); int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat"); int repeat = args.get_int("repeat");
int max_output_ids = int max_output_ids =
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
...@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -90,6 +95,30 @@ bool test_moe_sorting(ck_tile::ArgParser args)
return false; return false;
} }
bool local_expert_masking = args.get_str("local_eid") != "-1";
auto local_expert_masking_host = [&]() {
if(local_expert_masking)
{
auto local_eid = args.get_int_vec("local_eid");
// std::vector<int> v_ {num_experts, 0};
ck_tile::HostTensor<IndexType> v_{{num_experts}};
v_.SetZero();
for(auto eid : local_eid)
{
if(eid >= num_experts)
{
throw std::runtime_error(
"local_eid larger than number of expert, please check");
}
v_.mData[eid] = 1;
}
return v_;
}
else
// return std::vector<int>{};
return ck_tile::HostTensor<IndexType>{{1}};
}();
// tokens already considered batch size // tokens already considered batch size
ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1}); ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1}); ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
...@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -111,6 +140,8 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_expert_ids_host.get_element_space_size_in_bytes()); sorted_expert_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem local_expert_masking_dev(
local_expert_masking_host.get_element_space_size_in_bytes());
topk_ids_dev.ToDevice(topk_ids_host.data()); topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data()); weights_dev.ToDevice(weights_host.data());
...@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -118,11 +149,15 @@ bool test_moe_sorting(ck_tile::ArgParser args)
{ {
moe_buf_dev.ToDevice(moe_buf_host.data()); moe_buf_dev.ToDevice(moe_buf_host.data());
} }
if(local_expert_masking)
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
moe_sorting_trait trait{index_prec, weight_prec}; moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
weights_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(),
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
: nullptr,
sorted_ids_dev.GetDeviceBuffer(), sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(),
sorted_expert_ids_dev.GetDeviceBuffer(), sorted_expert_ids_dev.GetDeviceBuffer(),
...@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -140,15 +175,22 @@ bool test_moe_sorting(ck_tile::ArgParser args)
warmup, warmup,
repeat}; repeat};
auto ms = moe_sorting(trait, karg, sc); auto ms = moe_sorting(trait, karg, sc);
printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ", printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ",
index_prec.c_str(), index_prec.c_str(),
weight_prec.c_str(), weight_prec.c_str(),
tokens, tokens,
num_experts, num_experts,
topk, topk);
ms);
if(local_expert_masking)
{
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
}
if(ms < 0) if(ms < 0)
printf("not supported\n"); printf("not supported\n");
else
printf("ms:%f, ", ms);
fflush(stdout); fflush(stdout);
if(ms < 0) if(ms < 0)
{ {
...@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -174,12 +216,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int32_t ref_total_tokens_post_pad = 0; int32_t ref_total_tokens_post_pad = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host, ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
weights_host, weights_host,
local_expert_masking_host,
sorted_ids_ref, sorted_ids_ref,
sorted_weights_ref, sorted_weights_ref,
sorted_expert_ids_ref, sorted_expert_ids_ref,
ref_total_tokens_post_pad, ref_total_tokens_post_pad,
num_experts, num_experts,
unit_size); unit_size,
local_expert_masking);
rtn &= ck_tile::check_err( rtn &= ck_tile::check_err(
sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
rtn &= ck_tile::check_err(sorted_weights_host, rtn &= ck_tile::check_err(sorted_weights_host,
...@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args) ...@@ -199,9 +243,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
} }
rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
printf("total_tokens_post_pad:%d(%d), ",
ref_total_tokens_post_pad,
sorted_id_cnt_host.mData[0]);
} }
printf("valid:%s\n", rtn ? "y" : "n"); printf("valid:%s", rtn ? "y" : "n");
fflush(stdout);
if(!rtn)
printf(", (%d)", seed);
printf("\n");
fflush(stdout); fflush(stdout);
return rtn; return rtn;
} }
......
...@@ -3,6 +3,12 @@ ...@@ -3,6 +3,12 @@
#include "moe_sorting_api.hpp" #include "moe_sorting_api.hpp"
#ifndef MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_USE_EX_KERNEL 1
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \ #define MOE_SORTING_DISPATCH_ETILE(unroll_num_, expert_tile_) \
constexpr ck_tile::index_t unroll_num = unroll_num_; \ constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr ck_tile::index_t expert_tile = expert_tile_; \ constexpr ck_tile::index_t expert_tile = expert_tile_; \
...@@ -17,6 +23,67 @@ ...@@ -17,6 +23,67 @@
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time; return ave_time;
#else
#define MOE_SORTING_DISPATCH_(sub_token_tile_, sub_token_onshot_, local_expert_masking_) \
constexpr ck_tile::index_t sub_token_tile = sub_token_tile_; \
constexpr bool sub_token_onshot = sub_token_onshot_; \
constexpr bool local_expert_masking = local_expert_masking_; \
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
ms_weight_type, \
sub_token_tile, \
sub_token_onshot, \
local_expert_masking>; \
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
const auto lds_bytes = kernel::GetSmemSize(a); \
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
return ave_time;
#define MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, sub_token_onshot_, local_expert_masking_) \
if(row_ % 8 == 0) \
{ \
MOE_SORTING_DISPATCH_(8, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 4 == 0) \
{ \
MOE_SORTING_DISPATCH_(4, sub_token_onshot_, local_expert_masking_); \
} \
else if(row_ % 2 == 0) \
{ \
MOE_SORTING_DISPATCH_(2, sub_token_onshot_, local_expert_masking_); \
} \
else \
{ \
MOE_SORTING_DISPATCH_(1, sub_token_onshot_, local_expert_masking_); \
}
#define MOE_SORTING_DISPATCH_SUBTO_(row_, local_expert_masking_) \
if(is_sub_token_onshot) \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, true, local_expert_masking_) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUB_TOKEN_(row_, false, local_expert_masking_) \
}
#define MOE_SORTING_DISPATCH_EMASK_(row_) \
if(is_local_expert_masking) \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, true) \
} \
else \
{ \
MOE_SORTING_DISPATCH_SUBTO_(row_, false) \
}
#endif
#if !MOE_SORTING_USE_EX_KERNEL
#define MOE_SORTING_DISPATCH(unroll_num_) \ #define MOE_SORTING_DISPATCH(unroll_num_) \
if(a.num_experts <= 8) \ if(a.num_experts <= 8) \
{ \ { \
...@@ -38,11 +105,13 @@ ...@@ -38,11 +105,13 @@
{ \ { \
MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \ MOE_SORTING_DISPATCH_ETILE(unroll_num_, 0) \
} }
#endif
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{ {
if(t.weight_type == "fp32" && t.index_type == "int32") if(t.weight_type == "fp32" && t.index_type == "int32")
{ {
#if !MOE_SORTING_USE_EX_KERNEL
if(a.num_experts > 127) if(a.num_experts > 127)
{ {
printf("lds size exceed, only support experts <127 \n"); printf("lds size exceed, only support experts <127 \n");
...@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi ...@@ -83,6 +152,19 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
MOE_SORTING_DISPATCH(4); MOE_SORTING_DISPATCH(4);
} }
} }
#else
using index_t = ck_tile::index_t;
using ms_weight_type = float;
auto [r_, c_] = ck_tile::moe_sorting_get_smem_row_col(a.tokens, a.num_experts);
auto sub_token_ = r_ - 2;
r_ = (r_ - 2) / 8;
bool is_sub_token_onshot = a.tokens <= sub_token_;
bool is_local_expert_masking = t.local_expert_masking;
(void)c_;
MOE_SORTING_DISPATCH_EMASK_(r_);
// MOE_SORTING_DISPATCH_ETILE(0, 0);
#endif
} }
return -1; return -1;
} }
...@@ -10,7 +10,8 @@ ...@@ -10,7 +10,8 @@
struct moe_sorting_trait struct moe_sorting_trait
{ {
std::string index_type; std::string index_type;
std::string weight_type; // currently always float std::string weight_type; // currently always float
bool local_expert_masking; // if mask experts as local expert
}; };
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
......
...@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11 ...@@ -17,4 +17,12 @@ $EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1 $EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1 $EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13 $EXE -t=333 -e=99 -k=13
$EXE -t=11 -e=256 -k=5
$EXE -t=64 -e=455 -k=8
$EXE -t=777 -e=802 -k=99
$EXE -t=4097 -e=906 -k=51
$EXE -t=128 -e=32 -k=5 -moe_buf_size=262144 $EXE -t=128 -e=32 -k=5 -moe_buf_size=262144
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
...@@ -15,8 +15,13 @@ template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true ...@@ -15,8 +15,13 @@ template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
#endif #endif
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -6,8 +6,13 @@ ...@@ -6,8 +6,13 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd 2p // rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -6,9 +6,14 @@ ...@@ -6,9 +6,14 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd 2p // rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -6,7 +6,11 @@ ...@@ -6,7 +6,11 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd 2p // rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -6,9 +6,13 @@ ...@@ -6,9 +6,13 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd 2p // rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 6, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -6,9 +6,13 @@ ...@@ -6,9 +6,13 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd 2p // rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -6,9 +6,13 @@ ...@@ -6,9 +6,13 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd 2p // rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, true>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, true>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
// clang-format on // clang-format on
...@@ -6,8 +6,13 @@ ...@@ -6,8 +6,13 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd 2p // rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 4, 64, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 4, 64, 1, true, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -6,7 +6,11 @@ ...@@ -6,7 +6,11 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd 2p // rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 1, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 2, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 4, 64, 1, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 1, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 2, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 4, 64, 1, true , false>>(const S&, A);
// clang-format on // clang-format on
...@@ -6,7 +6,11 @@ ...@@ -6,7 +6,11 @@
// clang-format off // clang-format off
// rm rn tm tn vn pd 2p // rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 12, 4, 64, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 6, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 12, 4, 64, 1, true, false>>(const S&, A);
// clang-format on // clang-format on
...@@ -15,8 +15,13 @@ template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true ...@@ -15,8 +15,13 @@ template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true ,false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true ,false>>(const S&, A);
#endif #endif
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A); template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
// clang-format on // clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment