Unverified Commit fbd65454 authored by rocking's avatar rocking Committed by GitHub
Browse files

[Ck_tile] smoothquant (#1617)



* fix compile error

* fix typo of padding

* Add smoothquant op

* Add smoothquant instance library

* refine type

* add test script

* Re-generate smoothquant.hpp

* Always use 'current year' in copyright

* use Generic2dBlockShape instead

* Add vector = 8 instance back

* Find exe path automatically

* Simplify the api condition

* Remove debugging code

* update year

* Add blank line between function declaration

* explicitly cast return value to dim3

* refine return value

* Fix default warmup and repeat value

* Add comment

* refactor sommthquant cmake

* Add README

* Fix typo

---------
Co-authored-by: default avatarPo Yen, Chen <PoYen.Chen@amd.com>
parent 550248de
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename Problem_, typename Policy_ = SmoothquantPipelineDefaultPolicy>
struct SmoothquantPipelineTwoPass
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using XScaleDataType = ck_tile::remove_cvref_t<typename Problem::XScaleDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using QYDataType = ck_tile::remove_cvref_t<typename Problem::QYDataType>;
using YScaleDataType = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM
static constexpr bool kPadN = Problem::kPadN;
static constexpr const char* name = []() {
if constexpr(kNeedCrossWarpSync)
return "bpr_tp"; // block per row
else
return "wpr_tp"; // warp per row
}();
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename XWindow, typename XScaleWindow, typename QYWindow, typename YScaleWindow>
CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
const XScaleWindow& xscale_window_,
YScaleWindow& yscale_window,
QYWindow& qy_window,
ck_tile::index_t row_size,
void* smem) const
{
auto x_window =
make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
auto xscale_window = make_tile_window(
xscale_window_, Policy::template MakeXScaleBlockTileDistribution<Problem>());
static constexpr index_t Block_N = Problem::BlockShape::Block_N;
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));
auto reduce_absmax_func = ReduceOp::AbsMax{};
auto reduce_max_func = ReduceOp::Max{};
auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
auto block_reduce2d_cross_warp_sync =
Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
auto absmax = block_reduce2d.template MakeYBlockTile<XTensorType>();
set_tile(absmax, reduce_absmax_func.GetIdentityValue<ComputeDataType>());
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window);
const auto y = tile_elementwise_in(
[&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
},
x,
xscale);
block_reduce2d(y, absmax, reduce_absmax_func);
move_tile_window(x_window, {0, Block_N});
move_tile_window(xscale_window, {Block_N});
}
// compute absmax, cross-lane->cross-warp
block_reduce2d_sync(absmax, reduce_max_func);
block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
// ex: yscale = absmax / 127 if int8
auto yscale = tile_elementwise_in(
[&](const auto& v_) {
return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
},
absmax);
store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));
// reverse read x to reuse cache
ck_tile::index_t stride_to_right_most_window =
row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
move_tile_window(x_window, {0, -Block_N});
move_tile_window(xscale_window, {-Block_N});
move_tile_window(qy_window, {0, stride_to_right_most_window});
// recompute y and quantize y to qy
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
const auto x = load_tile(x_window);
const auto xscale = load_tile(xscale_window);
const auto y = tile_elementwise_in(
[&](const auto& a, const auto& b) {
return type_convert<ComputeDataType>(a) * type_convert<ComputeDataType>(b);
},
x,
xscale);
auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
sweep_tile(qy, [&](auto idx) {
constexpr auto i_idx = make_tuple(idx[number<0>{}]);
auto qy_ = y[idx] / yscale[i_idx];
qy(idx) = saturates<QYDataType>{}(qy_);
});
store_tile(qy_window, qy);
move_tile_window(x_window, {0, -Block_N});
move_tile_window(xscale_window, {0, -Block_N});
move_tile_window(qy_window, {0, -Block_N});
}
}
};
} // namespace ck_tile
from datetime import datetime
import pathlib import pathlib
from pathlib import Path from pathlib import Path
import subprocess import subprocess
...@@ -8,8 +9,8 @@ NS = 'ck_tile' ...@@ -8,8 +9,8 @@ NS = 'ck_tile'
OPS = 'ops' OPS = 'ops'
OPS_COMMON = 'common' # common header will be duplicated into ops/* other module OPS_COMMON = 'common' # common header will be duplicated into ops/* other module
HEADER_COMMON = """// SPDX-License-Identifier: MIT HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n // Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n
""" """
# aa/bb/cc/file.hpp -> (aa, bb, cc, file.hpp) # aa/bb/cc/file.hpp -> (aa, bb, cc, file.hpp)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment