Commit 019d4b7c authored by illsilin's avatar illsilin
Browse files

merge from public repo

parents 07307ea1 5063a39f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
#if 0
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false, true>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false, true>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "rmsnorm2d_fwd.hpp"
#include <iostream>
#pragma once
using S = ck_tile::stream_config;
using A = rmsnorm2d_fwd_args;
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
using trait_ = rmsnorm2d_fwd_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kSaveInvRms_,
kTwoPass_>;
template <typename Traits_>
float rmsnorm2d_fwd_(const S& s, A a)
{
using DataType = typename Traits_::DataType;
using PipelineProblem =
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<DataType>::XDataType,
typename RmsnormTypeConfig<DataType>::GammaDataType,
typename RmsnormTypeConfig<DataType>::ComputeDataType,
typename RmsnormTypeConfig<DataType>::YDataType,
typename RmsnormTypeConfig<DataType>::InvRmsDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kSaveInvRms,
Traits_::kTwoPass>;
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -8,27 +8,34 @@ ...@@ -8,27 +8,34 @@
#include "ck_tile/ops/rmsnorm2d.hpp" #include "ck_tile/ops/rmsnorm2d.hpp"
#include <string> #include <string>
template <typename DataType> template <typename InType,
typename OutType,
typename SmoothScaleDataType_,
typename YScaleDataType_>
struct RmsnormTypeConfig; struct RmsnormTypeConfig;
template <> template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
struct RmsnormTypeConfig<ck_tile::half_t> struct RmsnormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleDataType_>
{ {
using XDataType = ck_tile::half_t; using XDataType = ck_tile::half_t;
using YDataType = ck_tile::half_t; using YDataType = OutType;
using GammaDataType = ck_tile::half_t; using GammaDataType = ck_tile::half_t;
using InvRmsDataType = ck_tile::half_t; using InvRmsDataType = ck_tile::half_t;
using ComputeDataType = float; using ComputeDataType = float;
using SmoothScaleDataType = SmoothScaleDataType_;
using YScaleDataType = YScaleDataType_;
}; };
template <> template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
struct RmsnormTypeConfig<ck_tile::bf16_t> struct RmsnormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleDataType_>
{ {
using XDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t;
using YDataType = ck_tile::bf16_t; using YDataType = OutType;
using GammaDataType = ck_tile::bf16_t; using GammaDataType = ck_tile::bf16_t;
using InvRmsDataType = ck_tile::bf16_t; using InvRmsDataType = ck_tile::bf16_t;
using ComputeDataType = float; using ComputeDataType = float;
using SmoothScaleDataType = SmoothScaleDataType_;
using YScaleDataType = YScaleDataType_;
}; };
// runtime args // runtime args
...@@ -36,82 +43,24 @@ struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs ...@@ -36,82 +43,24 @@ struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs
{ {
}; };
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kSaveInvRms_,
bool kTwoPass_>
struct rmsnorm2d_fwd_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return total_warps * (warpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(warpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / warpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(warpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % warpSize == 0);
return ThreadPerBlock_N_ / warpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kTwoPass = kTwoPass_;
};
template <typename Traits_> template <typename Traits_>
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a); float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);
// This is the public API, will be generated by script // This is the public API, will be generated by script
struct rmsnorm2d_fwd_traits struct rmsnorm2d_fwd_traits
{ {
std::string data_type; std::string prec_i; // input precision
std::string prec_o; // output precision
// if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
// can set arbitrary(will skip check)
std::string prec_sm; // x-scale, used for [1*N] input smooth quant
std::string prec_sy; // y-scale, used for [M*1] output for next layer
bool save_rms; bool save_rms;
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
}; };
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&); float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&);
#!/bin/sh #!/bin/sh
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8"; do
for pr_i in "fp16" "bf16" ; do for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13 for fadd in "0" "1"; do
$EXE -prec=$pr_i -m=17 -n=16 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
$EXE -prec=$pr_i -m=1 -n=100 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
$EXE -prec=$pr_i -m=4 -n=128 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
$EXE -prec=$pr_i -m=80 -n=127 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
$EXE -prec=$pr_i -m=22 -n=255 -stride=256 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
$EXE -prec=$pr_i -m=7 -n=599 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
$EXE -prec=$pr_i -m=19 -n=512 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
$EXE -prec=$pr_i -m=11 -n=510 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
$EXE -prec=$pr_i -m=171 -n=676 -stride=818 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
$EXE -prec=$pr_i -m=91 -n=636 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
$EXE -prec=$pr_i -m=12 -n=768 -stride=800 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
$EXE -prec=$pr_i -m=100 -n=766 -stride=812 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
$EXE -prec=$pr_i -m=31 -n=1024 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
$EXE -prec=$pr_i -m=8 -n=1501 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
$EXE -prec=$pr_i -m=3 -n=1826 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
$EXE -prec=$pr_i -m=5 -n=2040 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
$EXE -prec=$pr_i -m=7 -n=2734 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
$EXE -prec=$pr_i -m=1 -n=3182 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec=$pr_i -m=9 -n=4096 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec=$pr_i -m=3 -n=8192 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec=$pr_i -m=1 -n=10547 $EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
$EXE -prec=$pr_i -m=3 -n=17134 #$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done done
...@@ -63,17 +63,17 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -63,17 +63,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
int warmup = arg_parser.get_int("warmup"); int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat"); int repeat = arg_parser.get_int("repeat");
assert(stride >= n); assert(x_stride >= n);
using XDataType = DataType; using XDataType = DataType;
using XScaleDataType = float; using SmoothScaleDataType = float;
using YScaleDataType = float; using YScaleDataType = float;
using QYDataType = ck_tile::int8_t; using QYDataType = ck_tile::int8_t;
using ComputeDataType = float; using ComputeDataType = float;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n}); ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
...@@ -82,15 +82,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -82,15 +82,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host); ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
xscale_buf.ToDevice(xscale_host.data()); smscale_buf.ToDevice(smscale_host.data());
constexpr bool kTwoPass = true; constexpr bool kTwoPass = true;
...@@ -101,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -101,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>; using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::SmoothquantPipelineProblem<XDataType, using Problem = ck_tile::SmoothquantPipelineProblem<XDataType,
XScaleDataType, SmoothScaleDataType,
ComputeDataType, ComputeDataType,
YScaleDataType, YScaleDataType,
QYDataType, QYDataType,
...@@ -115,7 +115,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -115,7 +115,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Kernel = ck_tile::Smoothquant<Pipeline>; using Kernel = ck_tile::Smoothquant<Pipeline>;
ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(), ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(),
xscale_buf.GetDeviceBuffer(), smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(),
m, m,
...@@ -142,16 +142,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -142,16 +142,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// smooth outlier // smooth outlier
{ {
auto f = [&](auto n_) { auto f = [&](auto n_) {
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_)); auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_) for(int m_ = 0; m_ < m; ++m_)
{ {
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_)); auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_xscale; y_host(m_, n_) = v_x * v_smscale;
} }
}; };
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
} }
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp> #include <ck_tile/core.hpp>
#include "smoothquant.hpp" #include "smoothquant.hpp"
...@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a) ...@@ -35,7 +35,7 @@ float smoothquant_(const S& s, A a)
using PipelineProblem = ck_tile::SmoothquantPipelineProblem< using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
typename SmoothquantTypeConfig<DataType>::XDataType, typename SmoothquantTypeConfig<DataType>::XDataType,
typename SmoothquantTypeConfig<DataType>::XScaleDataType, typename SmoothquantTypeConfig<DataType>::SmoothScaleDataType,
typename SmoothquantTypeConfig<DataType>::ComputeDataType, typename SmoothquantTypeConfig<DataType>::ComputeDataType,
typename SmoothquantTypeConfig<DataType>::YScaleDataType, typename SmoothquantTypeConfig<DataType>::YScaleDataType,
typename SmoothquantTypeConfig<DataType>::QYDataType, typename SmoothquantTypeConfig<DataType>::QYDataType,
......
...@@ -67,14 +67,14 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -67,14 +67,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
using TypeConfig = SmoothquantTypeConfig<DataType>; using TypeConfig = SmoothquantTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType; using XDataType = typename TypeConfig::XDataType;
using XScaleDataType = typename TypeConfig::XScaleDataType; using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType; using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType; using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = typename TypeConfig::ComputeDataType; using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n}); ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1}); ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
...@@ -83,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -83,15 +83,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1}); ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host); ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host); ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data()); x_buf.ToDevice(x_host.data());
xscale_buf.ToDevice(xscale_host.data()); smscale_buf.ToDevice(smscale_host.data());
std::cout << "[" << data_type << "]" std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
...@@ -100,7 +100,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -100,7 +100,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
smoothquant_traits traits{data_type}; smoothquant_traits traits{data_type};
smoothquant_args args{x_buf.GetDeviceBuffer(), smoothquant_args args{x_buf.GetDeviceBuffer(),
xscale_buf.GetDeviceBuffer(), smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(), yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(), qy_buf.GetDeviceBuffer(),
m, m,
...@@ -111,7 +111,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -111,7 +111,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
float ave_time = smoothquant( float ave_time = smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XScaleDataType) * n + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n +
sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n; sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time;
...@@ -126,16 +126,16 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -126,16 +126,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
// smooth outlier // smooth outlier
{ {
auto f = [&](auto n_) { auto f = [&](auto n_) {
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_)); auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_) for(int m_ = 0; m_ < m; ++m_)
{ {
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_)); auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_xscale; y_host(m_, n_) = v_x * v_smscale;
} }
}; };
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency()); std::thread::hardware_concurrency());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment