Unverified Commit 3c5717df authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge branch 'develop' into gemm_elementwise_gemm

parents 171b9030 d9f1ead3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 128, 8,true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 256, 4,true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 6, 1, 256, 2,true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 1024, 1,true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 128, 8,true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 256, 4,true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 6, 1, 256, 2,true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 1024, 1,true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 8, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 4, 64, 4, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 4, 64, 2, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 4, 64, 1, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 8, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 4, 64, 4, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 4, 64, 2, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 4, 64, 1, true , false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 4, 64, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 1, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 2, true, false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 4, 64, 4, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 6, 4, 64, 2, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 12, 4, 64, 1, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 4, 64, 4, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 6, 4, 64, 2, true , false>>(const S&, A);
template float moe_smoothquant_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 12, 4, 64, 1, true , false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "moe_smoothquant.hpp"
template <typename InType,
typename OutType,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kTwoPass_>
using trait_ = moe_smoothquant_traits_<InType,
OutType,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kTwoPass_>;
template <typename in_type, typename out_type>
float moe_smoothquant_dispatch(moe_smoothquant_traits /*t*/,
moe_smoothquant_args a,
const ck_tile::stream_config& s)
{
float r = -1;
// clang-format off
// rm rn tm tn vn pd 2p
if(a.hidden_size <= 64) {
r = moe_smoothquant_<trait_<in_type, out_type, 1, 1, 4, 64, 1, true, false>>(s, a);
}
else if(a.hidden_size <= 128) {
if (a.hidden_size % 2 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 1, 4, 64, 2, true, false>>(s, a);
else
r = moe_smoothquant_<trait_<in_type, out_type, 1, 2, 4, 64, 1, true, false>>(s, a);
}
else if(a.hidden_size <= 256) {
if (a.hidden_size % 4 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 1, 4, 64, 4, true, false>>(s, a);
else if (a.hidden_size % 2 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 2, 4, 64, 2, true, false>>(s, a);
else
r = moe_smoothquant_<trait_<in_type, out_type, 1, 4, 4, 64, 1, true, false>>(s, a);
}
else if(a.hidden_size <= 512) {
if (a.hidden_size % 8 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 1, 4, 64, 8, true, false>>(s, a);
else if (a.hidden_size % 4 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 2, 4, 64, 4, true, false>>(s, a);
else if (a.hidden_size % 2 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 4, 4, 64, 2, true, false>>(s, a);
else
r = moe_smoothquant_<trait_<in_type, out_type, 1, 8, 4, 64, 1, true, false>>(s, a);
}
else if(a.hidden_size <= 768) {
if (a.hidden_size % 4 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 3, 4, 64, 4, true, false>>(s, a);
else if (a.hidden_size % 2 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 6, 4, 64, 2, true, false>>(s, a);
else
r = moe_smoothquant_<trait_<in_type, out_type, 1,12, 4, 64, 1, true, false>>(s, a);
}
else if(a.hidden_size <= 1024) {
if (a.hidden_size % 8 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 1, 2, 128, 8, true, false>>(s, a);
else if (a.hidden_size % 4 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 2, 2, 128, 4, true, false>>(s, a);
else if (a.hidden_size % 2 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 4, 2, 128, 2, true, false>>(s, a);
else
r = moe_smoothquant_<trait_<in_type, out_type, 1, 4, 1, 256, 1, true, false>>(s, a);
}
else if(a.hidden_size <= 1536) {
if (a.hidden_size % 8 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 3, 4, 64, 8, true, false>>(s, a);
else if (a.hidden_size % 4 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 3, 2, 128, 4, true, false>>(s, a);
else if (a.hidden_size % 2 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 3, 1, 256, 2, true, false>>(s, a);
else
r = moe_smoothquant_<trait_<in_type, out_type, 1, 6, 1, 256, 1, true, false>>(s, a);
}
else if(a.hidden_size <= 2048) {
if (a.hidden_size % 8 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 1, 1, 256, 8, true, false>>(s, a);
else if (a.hidden_size % 4 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 2, 1, 256, 4, true, false>>(s, a);
else if (a.hidden_size % 2 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 4, 1, 256, 2, true, false>>(s, a);
else
r = moe_smoothquant_<trait_<in_type, out_type, 1, 8, 1, 256, 1, true, false>>(s, a);
}
else if(a.hidden_size <= 3072) {
if (a.hidden_size % 8 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 3, 1, 128, 8, true, false>>(s, a);
else if (a.hidden_size % 4 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 3, 1, 256, 4, true, false>>(s, a);
else if (a.hidden_size % 2 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 6, 1, 256, 2, true, false>>(s, a);
else
r = moe_smoothquant_<trait_<in_type, out_type, 1, 3, 1, 1024, 1, true, false>>(s, a);
}
else if(a.hidden_size <= 4096) {
if (a.hidden_size % 8 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 2, 1, 256, 8, true, false>>(s, a);
else if (a.hidden_size % 4 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 4, 1, 256, 4, true, false>>(s, a);
else if (a.hidden_size % 2 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 2, 1, 1024, 2, true, false>>(s, a);
else
r = moe_smoothquant_<trait_<in_type, out_type, 1, 4, 1, 1024, 1, true, false>>(s, a);
}
else if(a.hidden_size > 4096) {
if (a.hidden_size % 8 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 2, 1, 256, 8, true, true>>(s, a);
else if (a.hidden_size % 4 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 4, 1, 256, 4, true, true>>(s, a);
else if (a.hidden_size % 2 == 0)
r = moe_smoothquant_<trait_<in_type, out_type, 1, 2, 1, 1024, 2, true, true>>(s, a);
else
r = moe_smoothquant_<trait_<in_type, out_type, 1, 4, 1, 1024, 1, true, true>>(s, a);
}
return r;
// clang-format on
}
float moe_smoothquant(moe_smoothquant_traits t,
moe_smoothquant_args a,
const ck_tile::stream_config& s)
{
if(t.in_type.compare("fp16") == 0 && t.out_type == "int8")
{
return moe_smoothquant_dispatch<ck_tile::fp16_t, ck_tile::int8_t>(t, a, s);
}
else if(t.in_type.compare("fp16") == 0 && t.out_type == "fp8")
{
return moe_smoothquant_dispatch<ck_tile::fp16_t, ck_tile::fp8_t>(t, a, s);
}
else if(t.in_type.compare("bf16") == 0 && t.out_type == "int8")
{
return moe_smoothquant_dispatch<ck_tile::bf16_t, ck_tile::int8_t>(t, a, s);
}
else if(t.in_type.compare("bf16") == 0 && t.out_type == "fp8")
{
return moe_smoothquant_dispatch<ck_tile::bf16_t, ck_tile::fp8_t>(t, a, s);
}
else
throw std::runtime_error("Without supported instances!");
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
#include "moe_smoothquant.hpp"
#include <iostream>
#pragma once
using S = ck_tile::stream_config;
using A = rmsnorm2d_fwd_args;
using A = moe_smoothquant_args;
template <typename DataType_,
template <typename InputType_,
typename OutputType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
using trait_ = rmsnorm2d_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveInvRms_,
kTwoPass_>;
using trait_ = moe_smoothquant_traits_<InputType_,
OutputType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kTwoPass_>;
template <typename Traits_>
float rmsnorm2d_fwd_(const S& s, A a)
float moe_smoothquant_(const S& s, A a)
{
using DataType = typename Traits_::DataType;
using InputType = typename Traits_::InputType;
using OutputType = typename Traits_::OutputType;
using PipelineProblem =
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<DataType>::XDataType,
typename RmsnormTypeConfig<DataType>::GammaDataType,
typename RmsnormTypeConfig<DataType>::ComputeDataType,
typename RmsnormTypeConfig<DataType>::YDataType,
typename RmsnormTypeConfig<DataType>::InvRmsDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kSaveInvRms,
Traits_::kTwoPass>;
using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
typename MoeSmoothquantTypeConfig<InputType, OutputType>::XDataType,
typename MoeSmoothquantTypeConfig<InputType, OutputType>::SmoothScaleDataType,
typename MoeSmoothquantTypeConfig<InputType, OutputType>::ComputeDataType,
typename MoeSmoothquantTypeConfig<InputType, OutputType>::YScaleDataType,
typename MoeSmoothquantTypeConfig<InputType, OutputType>::QYDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kTwoPass>;
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline>;
using Kernel = ck_tile::MoeSmoothquant<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
......
#include "ck_tile/host.hpp"
#include "moe_smoothquant.hpp"
#include <cstring>
#include <set>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
// due to rounding, int8 quantization might have 1 abs error
double rtol = 1;
double atol = 1;
return ck_tile::make_tuple(rtol, atol);
}
template <typename IndexType>
void topid_unique_gen(
std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
size_t total_size = topk * tokens;
std::srand(seed);
std::set<IndexType> unique_set;
IndexType current_v;
for(size_t i = 0; i < total_size; i++)
{
if(i % topk == 0)
{
unique_set.clear();
}
current_v = std::rand() % num_expert;
while(unique_set.find(current_v) != unique_set.end())
{
current_v = std::rand() % num_expert;
}
unique_set.insert(current_v);
host_tensor[i] = current_v;
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("t", "3328", "tokens dimension")
.insert("h", "4096", "hidden_size dimension")
.insert("e", "32", "experts")
.insert("k", "5", "topk")
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec_i", "fp16", "input precision, fp16/bf16")
.insert("prec_o", "int8", "precision, int8/fp8")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename InputType, typename OutputType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t tokens = arg_parser.get_int("t");
ck_tile::index_t hidden_size = arg_parser.get_int("h");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = hidden_size;
ck_tile::index_t experts = arg_parser.get_int("e");
ck_tile::index_t topk = arg_parser.get_int("k");
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(stride >= hidden_size);
using TypeConfig = MoeSmoothquantTypeConfig<InputType, OutputType>;
using XDataType = typename TypeConfig::XDataType;
using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({experts * hidden_size});
ck_tile::HostTensor<ck_tile::index_t> topk_ids_host({tokens, topk});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({topk * tokens}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({topk * tokens}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({topk * tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({topk * tokens, hidden_size}, {stride, 1});
topid_unique_gen<ck_tile::index_t>(topk_ids_host.mData, tokens, topk, experts, 11937);
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem topk_ids_buf(topk_ids_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
smscale_buf.ToDevice(smscale_host.data());
topk_ids_buf.ToDevice(topk_ids_host.data());
std::cout << "[" << prec_i << "-" << prec_o << "]"
<< " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride
<< ", experts:" << experts << ", topk:" << topk << std::flush;
moe_smoothquant_traits traits{prec_i, prec_o};
moe_smoothquant_args args{x_buf.GetDeviceBuffer(),
smscale_buf.GetDeviceBuffer(),
topk_ids_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
tokens,
hidden_size,
experts,
topk,
stride,
stride};
float ave_time = moe_smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = sizeof(XDataType) * tokens * hidden_size +
sizeof(SmoothScaleDataType) * topk * hidden_size +
sizeof(YScaleDataType) * topk * tokens +
sizeof(QYDataType) * topk * tokens * hidden_size;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
bool pass = true;
if(do_validation)
{
using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({topk * tokens, hidden_size}, {stride, 1});
// smooth outlier
{
auto f = [&](auto i_token) {
for(int i_topk = 0; i_topk < topk; i_topk++)
{
auto i_expert = topk_ids_host(i_token, i_topk);
for(int i_h = 0; i_h < hidden_size; ++i_h)
{
auto v_smscale = ck_tile::type_convert<ComputeDataType>(
smscale_host(i_expert * hidden_size + i_h));
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(i_token, i_h));
// y_host(i_token * topk + i_topk, i_h) = v_x * v_smscale;
y_host(i_topk * tokens + i_token, i_h) = v_x * v_smscale;
}
}
};
ck_tile::make_ParallelTensorFunctor(f, tokens)(std::thread::hardware_concurrency());
}
// yscale
{
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({topk * tokens});
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
y_host, y_rowwise_amax_host, ReduceAmax{});
auto op = [](const auto& v0) {
return v0 /
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
};
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
y_rowwise_amax_host, yscale_host_ref, op);
yscale_buf.FromDevice(yscale_host_dev.mData.data());
auto [rtol, atol] = get_elimit<YScaleDataType>();
pass &= ck_tile::check_err(yscale_host_dev,
yscale_host_ref,
std::string("yscale Error: Incorrect results!"),
rtol,
atol);
}
// rowwise quantization
{
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
y_host, yscale_host_ref, qy_host_ref);
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == hidden_size)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
std::string("qy Error: Incorrect results!"),
rtol,
atol);
}
else
{
for(int i_r = 0; i_r < topk * tokens; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
qy_host_dev.begin() + i_r * stride +
hidden_size);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
qy_host_ref.begin() + i_r * stride +
hidden_size);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string prec_i = arg_parser.get_str("prec_i");
const std::string prec_o = arg_parser.get_str("prec_o");
if(prec_i == "fp16" && prec_o == "int8")
{
return run<ck_tile::half_t, ck_tile::int8_t>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp8")
{
return run<ck_tile::half_t, ck_tile::fp8_t>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "int8")
{
return run<ck_tile::bf16_t, ck_tile::int8_t>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "fp8")
{
return run<ck_tile::bf16_t, ck_tile::fp8_t>(arg_parser) ? 0 : -2;
}
return -3;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <string>
template <typename InputType, typename OutputType>
struct MoeSmoothquantTypeConfig
{
using XDataType = InputType;
using SmoothScaleDataType = float;
using YScaleDataType = float;
using QYDataType = OutputType;
using ComputeDataType = float;
};
// runtime args
struct moe_smoothquant_args : public ck_tile::MoeSmoothquantHostArgs
{
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename InputType_,
typename OutputType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kTwoPass_>
struct moe_smoothquant_traits_
{
using InputType = ck_tile::remove_cvref_t<InputType_>;
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTwoPass = kTwoPass_;
};
template <typename Traits_>
float moe_smoothquant_(const ck_tile::stream_config& s, moe_smoothquant_args a);
// This is the public API, will be generated by script
struct moe_smoothquant_traits
{
std::string in_type; // input type
std::string out_type; // output type
};
float moe_smoothquant(moe_smoothquant_traits, moe_smoothquant_args, const ck_tile::stream_config&);
EXE=build/bin/tile_example_moe_smoothquant
$EXE -t=1 -h=1 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=80 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=128 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=144 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=168 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=184 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=256 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=288 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=344 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=376 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=448 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=512 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=924 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=1024 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=1078 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=1996 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=4080 -v=1 -prec=bf16 -repeat=1000
$EXE -t=700 -h=80 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=128 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=144 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=168 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=184 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=256 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=288 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=344 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=376 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=448 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=512 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=924 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=1024 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=1078 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=1996 -v=1 -prec=fp16 -repeat=1000
$EXE -t=700 -h=4080 -v=1 -prec=fp16 -repeat=1000
\ No newline at end of file
#!/bin/sh
EXE=build/bin/tile_example_moe_smoothquant
for pr_i in "fp16" "bf16" ; do
for pr_o in "int8" "fp8" ; do
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=99 -h=13
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=17 -h=16
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=1 -h=100
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=4 -h=128
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=80 -h=127
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=22 -h=255 -stride=256
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=7 -h=599
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=19 -h=512
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=33 -h=313 -stride=1000
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=11 -h=510
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=171 -h=676 -stride=818
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=91 -h=636
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=12 -h=768 -stride=800
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=100 -h=766 -stride=812
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=31 -h=1024
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=64 -h=1000 -stride=1004
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=8 -h=1501
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=3 -h=1826
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=5 -h=2040
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=7 -h=2734
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=1 -h=3182
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=9 -h=4096
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=3 -h=8192
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=1 -h=10547
$EXE -prec_i=$pr_i -prec_o=$pr_o -t=3 -h=17134
done
done
set(TILE_EXAPMLE_FUSED_MOE "tile_example_fused_moe")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding ${TILE_EXAPMLE_FUSED_MOE}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_EXAPMLE_FUSED_MOE} EXCLUDE_FROM_ALL main.cpp)
target_include_directories(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${INSTANCE_SRCS})
set(TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_BUFFER_LOAD_AGPR=1) # TODO: enable load to a
list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=4) # rta
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1)
# list(APPEND TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(${TILE_EXAPMLE_FUSED_MOE} PRIVATE ${TILE_EXAPMLE_FUSED_MOE_COMPILE_OPTIONS})
# fused-moe
Implementing the fused-moe block operator using ck-tile. This is a scatter/gather-group-gemm based solution, similiar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance
![](misc/moe-0.png)
The benifit of this fused-moe:
* 1.5~2x perf boost compared with current vllm solution
* zero workspace to reduce memory footprint
* much less kernel instance, easy to maintain
# Implementation and feature support
## NOTES:
currently gate+up in fp16 case will very easily cause accumulator overflow the fp16 max(65504), hence result in INF. Please use BF16 for gate+up case, API side will have no check for this.
## moe-sorting
this is a common pre-process step before the actual moe-gemm. The purpose is to transform the moe loop over from token-by-token to expert-by-expert, make sure very workgroup is working for a single expert (B matrix). Besides, we extend this op to do the zeroing of the output buffer(to be used for reduce buffer with atomic)
## moe-gemm
`moe-gemm` is a group-gemm based back-to-back gemm, where the row-id of input token comes from another buffer. Naive understanding of fused-moe is from token-by-token view as below picture:
![](misc/moe-1.png)
After `moe-sorting`, we can view this algorithm as expert-by-expert, as below:
![](misc/moe-2.png)
## optimization
summary of the key design of this fused-moe operator:
* fuse 2 group-gemm + activation + `topk-weight` multiply into single kernel, using atomic for 2nd gemm accumualation
* fuse buffer-zeroing in `moe-sorgin`, user no longer need call extra torch.zero() for the out buffer
* fused scatter-gather for row index(same as vllm)
* pre-shuffle B matric(weight) to maximize memory throughput. input(activation) keep original layout `[batch, hidden]`.
* extrem optimized pipeline using block-inline-asm(we call it `micro-kernel` or `uk`), while not breaking the *composable* design of ck
##
```
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr [max_num_tokens_padded]
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: opk_ids.numel() + num_experts * (block_size - 1)
```
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "fused_moesorting.hpp"
#include "fused_moegemm.hpp"
struct fused_moe_args
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token (no need to do zeroing)
const void* topk_ids_ptr; // [tokens, topk]
const void* topk_weight_ptr; // [tokens, topk]
void* sorted_token_ids_ptr; // [max_num_tokens_padded]
void* sorted_weight_ptr; // [max_num_tokens_padded]
void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
void* num_sorted_tiles_ptr; // [1]
ck_tile::index_t block_m; // block_m, used to devide the input
ck_tile::index_t hidden_size; // k
ck_tile::index_t intermediate_size; // n / TP, for Gate. and Up, Down is also this value
ck_tile::index_t num_tokens; // input number of tokens for current iteration
ck_tile::index_t num_experts; // number of groups
ck_tile::index_t topk; // need this?
ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
// This is the public API, will be generated by script
struct fused_moe_traits
{
std::string prec_i; // input precision
std::string prec_w; // weight precision
std::string prec_o; // output precision
std::string prec_st; // token scale data type
std::string prec_sw; // weight scale data type
std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type
int block_m;
int activation; // 0:gelu, 1:silu
int gate_only; // 0:g1u0, 1:g1u1
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
float fused_moe(fused_moe_traits, fused_moe_args, const ck_tile::stream_config&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include <string>
// this is only a convenient structure for creating an example
// this is not part of the host API
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig;
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::bf16_t;
using GDataType = ck_tile::bf16_t;
using DDataType = ck_tile::bf16_t;
using AccDataType = float;
using ODataType = ck_tile::bf16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::fp16_t;
using GDataType = ck_tile::fp16_t;
using DDataType = ck_tile::fp16_t;
using AccDataType = float;
using ODataType = ck_tile::fp16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>
{
using ADataType = ck_tile::int8_t;
using GDataType = ck_tile::int8_t;
using DDataType = ck_tile::int8_t;
using AccDataType = int32_t;
using ODataType = ck_tile::bf16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using GScaleDataType = ck_tile::remove_cvref_t<SW>;
using DScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
// runtime args
struct fused_moegemm_args : public ck_tile::FusedMoeGemmHostArgs
{
};
// This is the public API, will be generated by script
struct fused_moegemm_traits
{
std::string prec_i; // input precision
std::string prec_w; // weight precision
std::string prec_o; // output precision
std::string prec_st; // token scale data type
std::string prec_sw; // weight scale data type
std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type
int block_m;
int activation; // 0:gelu, 1:silu
int gate_only; // 0:g1u0, 1:g1u1
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment