"Src/vscode:/vscode.git/clone" did not exist on "cb8f00187c379a1a270bcedec54a6d649a02587c"
Unverified Commit fbd65454 authored by rocking's avatar rocking Committed by GitHub
Browse files

[Ck_tile] smoothquant (#1617)



* fix compile error

* fix typo of padding

* Add smoothquant op

* Add smoothquant instance library

* refine type

* add test script

* Re-generate smoothquant.hpp

* Always use 'current year' in copyright

* use Generic2dBlockShape instead

* Add vector = 8 instance back

* Find exe path automatically

* Simplify the api condition

* Remove debugging code

* update year

* Add blank line between function declaration

* explicitly cast return value to dim3

* refine return value

* Fix default warmup and repeat value

* Add comment

* refactor sommthquant cmake

* Add README

* Fix typo

---------
Co-authored-by: default avatarPo Yen, Chen <PoYen.Chen@amd.com>
parent 550248de
# run from top of ck folder
EXE=build/bin/tile_example_layernorm2d_fwd
#!/bin/sh
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000
......
#!/bin/sh
# call from top of CK folder
EXE=./build/bin/tile_example_layernorm2d_fwd
EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8"; do
for pr_i in "fp16" "bf16" ; do
......
......@@ -69,7 +69,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::Rmsnorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
GammaDataType,
ComputeDataType,
......
......@@ -28,7 +28,6 @@ float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
rmsnorm2d_fwd_args a,
const ck_tile::stream_config& s)
{
#if 1
float r = -1;
// clang-format off
// rm rn tm tn vn pd rms 2p
......@@ -128,16 +127,12 @@ float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, true>>(s, a);
}
return r;
#else
return rmsnorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 4, true, false, false>>(s, a);
#endif
// clang-format on
}
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s)
{
float r = -1;
if(t.data_type.compare("fp16") == 0)
{
return rmsnorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
......@@ -146,8 +141,6 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile:
{
return rmsnorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
}
if(r < 0)
else
throw std::runtime_error("Without supported instances!");
return r;
}
......@@ -97,7 +97,7 @@ struct rmsnorm2d_fwd_traits_
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Rmsnorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveInvRms = kSaveInvRms_;
......
# run from top of ck folder
EXE=build/bin/tile_rmsnorm2d_fwd
#!/bin/sh
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
......
#!/bin/sh
# call from top of CK folder
EXE=./build/bin/tile_rmsnorm2d_fwd
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13
......
......@@ -18,7 +18,7 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::half_t>
using BDataType = ck_tile::half_t;
using GammaDataType = ck_tile::half_t;
using XDataType = ck_tile::half_t;
using YScaleDataType = ck_tile::half_t;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
......@@ -30,7 +30,7 @@ struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t>
using BDataType = ck_tile::bf16_t;
using GammaDataType = ck_tile::bf16_t;
using XDataType = ck_tile::bf16_t;
using YScaleDataType = ck_tile::bf16_t;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
......@@ -101,7 +101,7 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::AddRmsnorm2dRdquantShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveX = kSaveX_;
......
......@@ -66,7 +66,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using BDataType = DataType;
using GammaDataType = DataType;
using XDataType = DataType;
using YScaleDataType = DataType;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
......@@ -99,12 +99,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
constexpr bool kThreePass = true;
using BlockWarps = ck_tile::sequence<2, 2>;
using BlockTile = ck_tile::sequence<2, 128>;
using BlockWarps = ck_tile::sequence<4, 1>;
using BlockTile = ck_tile::sequence<4, 128>;
using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::AddRmsnorm2dRdquantShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<ADataType,
BDataType,
GammaDataType,
......
......@@ -28,7 +28,6 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
add_rmsnorm2d_rdquant_fwd_args a,
const ck_tile::stream_config& s)
{
#if 1
float r = -1;
// clang-format off
// rm rn tm tn vn pd x 3p
......@@ -128,9 +127,6 @@ float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/,
r = add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, true, true>>(s, a);
}
return r;
#else
return add_rmsnorm2d_rdquant_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, true, false>>(s, a);
#endif
// clang-format on
}
......@@ -139,7 +135,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
const ck_tile::stream_config& s)
{
float r = -1;
// Only support instance of save_x == true for now
assert(t.save_x);
if(t.data_type.compare("fp16") == 0)
......@@ -150,8 +145,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
{
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t>(t, a, s);
}
if(r < 0)
else
throw std::runtime_error("Without supported instances!");
return r;
}
# run from top of ck folder
EXE=build/bin/tile_add_rmsnorm2d_rdquant_fwd
#!/bin/sh
EXE="$(find . -name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)"
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
......
#!/bin/sh
# call from top of CK folder
EXE=./build/bin/tile_add_rmsnorm2d_rdquant_fwd
EXE="$(find . -name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)"
for pr_i in "fp16" "bf16" ; do
$EXE -prec=$pr_i -m=99 -n=13
......
function (add_smoothquant_example TARGET_NAME MAIN_SRC)
message("adding ${TARGET_NAME}")
# not using add_example_executable() to add target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
foreach(source IN LISTS ARGN)
list(APPEND INSTANCE_SRCS ${source})
endforeach()
target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
set(COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS})
add_smoothquant_example(tile_example_smoothquant example_smoothquant.cpp)
# smoothquant
This folder contains example for smoothquant using ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_smoothquant -j
```
This will result in an executable `build/bin/tile_smoothquant`
## cmdline
```
args:
-m m dimension (default:3328)
-n m dimension (default:4096)
-v cpu validation or not (default:1)
-prec precision (default:fp16)
```
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <cstring>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
// due to rounding, int8 quantization might have 1 abs error
double rtol = 1;
double atol = 1;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "0", "cold iter")
.insert("repeat", "1", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(stride >= n);
using XDataType = DataType;
using XScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
ck_tile::HostTensor<XScaleDataType> xscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<XScaleDataType>{1e-3, .5f}(xscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
xscale_buf.ToDevice(xscale_host.data());
constexpr bool kTwoPass = true;
using BlockWarps = ck_tile::sequence<2, 2>;
using BlockTile = ck_tile::sequence<2, 128>;
using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::SmoothquantPipelineProblem<XDataType,
XScaleDataType,
ComputeDataType,
YScaleDataType,
QYDataType,
Shape,
true,
kTwoPass>;
using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<Problem>;
using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<Problem>;
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Smoothquant<Pipeline>;
ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(),
xscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
m,
n,
stride};
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto s = ck_tile::stream_config{nullptr, true, 1, warmup, repeat};
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
bool pass = true;
if(do_validation)
{
using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {stride, 1});
// smooth outlier
{
auto f = [&](auto n_) {
auto v_xscale = ck_tile::type_convert<ComputeDataType>(xscale_host(n_));
for(int m_ = 0; m_ < m; ++m_)
{
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_xscale;
}
};
ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())(
std::thread::hardware_concurrency());
}
// yscale
{
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
y_host, y_rowwise_amax_host, ReduceAmax{});
auto op = [](const auto& v0) {
return v0 /
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
};
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
y_rowwise_amax_host, yscale_host_ref, op);
yscale_buf.FromDevice(yscale_host_dev.mData.data());
auto [rtol, atol] = get_elimit<YScaleDataType>();
pass &= ck_tile::check_err(yscale_host_dev,
yscale_host_ref,
std::string("yscale Error: Incorrect results!"),
rtol,
atol);
}
// rowwise quantization
{
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
y_host, yscale_host_ref, qy_host_ref);
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(stride == n)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
std::string("qy Error: Incorrect results!"),
rtol,
atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
qy_host_dev.begin() + i_r * stride + n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
qy_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
/*else if(data_type == "bf16")
{
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
}*/
return -3;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
#if 0
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
#endif
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment