Commit fb9f0757 authored by Jun Liu

Merge branch 'amd-develop' into amd-master

parents 4d914af3 c5ad2e80
#!/bin/sh
EXE=./build/bin/tile_example_moe_sorting
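# Flag meanings below are presumed from the moe_sorting example CLI (not documented in this
# change): -t = number of input tokens, -e = number of experts, -k = top-k experts per token,
# -moe_buf_size = size of the auxiliary MoE workspace buffer.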
$EXE -t=80 -e=17 -moe_buf_size=16
$EXE -t=111 -e=117 -moe_buf_size=4
$EXE -t=1000 -e=55 -moe_buf_size=1024
$EXE -t=99 -e=120 -moe_buf_size=10244
$EXE -t=175 -e=64 -k=8
$EXE -t=65 -e=8 -k=2
$EXE -t=1 -e=25
$EXE -t=31 -e=19 -k=15
$EXE -t=81 -e=37 -k=7
$EXE -t=23 -e=1 -k=1
$EXE -t=127 -e=99 -k=19
$EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13
\ No newline at end of file
@@ -11,3 +11,5 @@ add_subdirectory(06_permute)
 add_subdirectory(09_topk_softmax)
 add_subdirectory(10_rmsnorm2d)
 add_subdirectory(11_add_rmsnorm2d_rdquant)
+add_subdirectory(12_smoothquant)
+add_subdirectory(13_moe_sorting)
@@ -63,13 +63,15 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
 #define __gfx101__
 #endif
 #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
-    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
+    defined(__gfx10_3_generic__)
 #define __gfx103__
 #endif
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
+    defined(__gfx1103__) || defined(__gfx11_generic__)
 #define __gfx11__
 #endif
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
 #define __gfx12__
 #endif
@@ -93,12 +93,12 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t e_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
+    const long_index_t a_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t e_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
@@ -60,12 +60,12 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t e_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
+    const long_index_t a_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t e_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
@@ -117,12 +117,12 @@ __global__ void
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
     const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t e_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
+    const long_index_t a_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t e_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     // Pass two lds pointer is the key to tell compiler that ds_read/write
     // operate on different lds chunk at same time without order dependecy
@@ -98,12 +98,12 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t c_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
+    const long_index_t a_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t c_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
@@ -60,12 +60,12 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t e_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
+    const long_index_t a_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t e_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
@@ -155,12 +155,12 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const long_index_t b_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
-    const long_index_t e_batch_offset =
-        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
+    const long_index_t a_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t e_batch_offset = amd_wave_read_first_lane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
@@ -121,10 +121,10 @@ struct GridwiseTensorRearrange
         __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

     // Global Memory
-    const index_t a_batch_offset =
-        __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
-    const index_t c_batch_offset =
-        __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
+    const index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

     const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
         p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize());
@@ -9,7 +9,8 @@
 // TODO: Add arch limitation
 namespace ck {

-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
+    defined(__gfx1103__) || defined(__gfx11_generic__)
 #define __gfx11__
 #endif
 /********************************WAVE32 MODE***********************************************/
@@ -260,7 +261,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
 // gfx12
 /********************************WAVE32 MODE***********************************************/
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
 #define __gfx12__
 #endif
@@ -11,13 +11,15 @@
 #define __gfx94__
 #endif
 #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
-    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__)
+    defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
+    defined(__gfx10_3_generic__)
 #define __gfx103__
 #endif
-#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__)
+#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
+    defined(__gfx1103__) || defined(__gfx11_generic__)
 #define __gfx11__
 #endif
-#if defined(__gfx1200__) || defined(__gfx1201__)
+#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
 #define __gfx12__
 #endif
@@ -170,7 +170,7 @@ CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
     }
     else
     {
-        // NOT implemented
+        static_assert(false, "The shuffle should always happen!");
     }
 }
@@ -23,6 +23,7 @@
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
 #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
+#include "ck_tile/host/reference/reference_moe_sorting.hpp"
 #include "ck_tile/host/reference/reference_permute.hpp"
 #include "ck_tile/host/reference/reference_reduce.hpp"
 #include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename WeightType, typename IndexType = index_t>
CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
                                        const HostTensor<WeightType>& weights,
                                        HostTensor<IndexType>& p_sorted_token_ids,
                                        HostTensor<WeightType>& sorted_weight,
                                        HostTensor<IndexType>& sorted_expert_ids,
                                        index_t& unit_cnt,
                                        const index_t experts,
                                        const index_t unit_size)
{
    const index_t num_token = topk_ids.mDesc.get_lengths()[0];
    const index_t topk      = topk_ids.mDesc.get_lengths()[1];

    std::vector<std::vector<IndexType>> expert_tokens(
        experts, std::vector<IndexType>(unit_size, num_token));
    std::vector<std::vector<WeightType>> expert_token_weights(
        experts, std::vector<WeightType>(unit_size, 0));
    std::vector<IndexType> expert_slices(experts, 1);
    std::vector<IndexType> expert_slice_idxs(experts, 0);

    for(index_t t = 0; t < num_token; t++)
    {
        for(index_t k = 0; k < topk; k++)
        {
            IndexType e  = topk_ids(t, k);
            WeightType w = weights(t, k);
            index_t idx  = expert_slice_idxs[e];
            if(idx > expert_slices[e] * unit_size - 1)
            {
                expert_slices[e]++;
                index_t new_size = expert_slices[e] * unit_size;
                expert_tokens[e].resize(new_size);
                expert_token_weights[e].resize(new_size);
                for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
                {
                    expert_tokens[e][i]        = num_token;
                    expert_token_weights[e][i] = 0;
                }
            }
            expert_tokens[e][idx]        = t;
            expert_token_weights[e][idx] = w;
            expert_slice_idxs[e]++;
        }
    }

    IndexType* out_tokens    = p_sorted_token_ids.data();
    WeightType* out_weights  = sorted_weight.data();
    IndexType* out_expert_id = sorted_expert_ids.data();
    for(index_t e = 0; e < experts; e++)
    {
        memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size);
        out_tokens += expert_slices[e] * unit_size;
        memcpy(out_weights,
               expert_token_weights[e].data(),
               sizeof(WeightType) * expert_slices[e] * unit_size);
        out_weights += expert_slices[e] * unit_size;
        for(index_t s = 0; s < expert_slices[e]; s++)
        {
            out_expert_id[s] = e;
            unit_cnt++;
        }
        out_expert_id += expert_slices[e];
    }
    unit_cnt *= unit_size;
    return;
}
} // namespace ck_tile
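A minimal host-side usage sketch of the new reference function follows. The sizes and the generous over-allocation of the sorted buffers are assumptions for illustration, not something this commit prescribes; the constructor style mirrors how HostTensor is used elsewhere in ck_tile.

    // Hypothetical sizes; the routing data itself is not shown here.
    ck_tile::index_t tokens = 4, experts = 8, topk = 2, unit_size = 32;

    ck_tile::HostTensor<ck_tile::index_t> topk_ids({tokens, topk}); // expert id per (token, k)
    ck_tile::HostTensor<float> weights({tokens, topk});             // routing weight per (token, k)
    // ... fill topk_ids (values in [0, experts)) and weights, e.g. from a topk-softmax step ...

    // Generous upper bound: one unit per (token, expert) assignment plus one unit per expert.
    ck_tile::index_t max_units = tokens * topk + experts;
    ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids({max_units * unit_size});
    ck_tile::HostTensor<float> sorted_weight({max_units * unit_size});
    ck_tile::HostTensor<ck_tile::index_t> sorted_expert_ids({max_units});

    ck_tile::index_t unit_cnt = 0;
    ck_tile::reference_moe_sorting(topk_ids, weights, sorted_token_ids, sorted_weight,
                                   sorted_expert_ids, unit_cnt, experts, unit_size);
    // unit_cnt now holds the total number of padded token slots written (units * unit_size).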
@@ -4,7 +4,6 @@
 #pragma once

 #include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp"
-#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp"
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
 #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
@@ -9,15 +9,16 @@
 namespace ck_tile {

 // host side args
+// X = A + B, Y = Rmsnorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
 struct AddRmsnorm2dRdquantFwdHostArgs
 {
-    const void* p_a;
-    const void* p_b;
-    const void* p_gamma;
-    void* p_x;
-    void* p_yscale;
-    void* p_qy;
+    const void* p_a;     // [m, n], input, fp16/bf16
+    const void* p_b;     // [m, n], input, fp16/bf16
+    const void* p_gamma; // [1, n], gamma, prec same as input
+    void* p_x;           // [m, n], output, p_a + p_b, fp16/bf16
+    void* p_yscale;      // [m, 1], output, rowwise quant scale (amax / 127) of the result of rmsnorm2d(x)
+    void* p_qy;          // [m, n], output, quantized result of rmsnorm2d(x), int8
     float epsilon;
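Spelled out, the new comment describes the following per-row computation (standard RMSNorm form assumed; epsilon and gamma are the arguments above):

    x_{ij} = a_{ij} + b_{ij}, \qquad
    y_{ij} = \frac{x_{ij}}{\sqrt{\tfrac{1}{n}\sum_{k} x_{ik}^2 + \epsilon}}\,\gamma_j, \qquad
    \mathrm{yscale}_i = \frac{\max_j |y_{ij}|}{127}, \qquad
    qy_{ij} = \mathrm{sat\_i8}\!\left(\frac{y_{ij}}{\mathrm{yscale}_i}\right)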
@@ -90,7 +91,7 @@ struct AddRmsnorm2dRdquantFwd
     CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
     {
-        return integer_divide_ceil(hargs.m, Block_M);
+        return dim3(integer_divide_ceil(hargs.m, Block_M));
     }

     CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
@@ -170,7 +171,7 @@ struct AddRmsnorm2dRdquantFwd
                     number<1>{});
             const auto tmp2_ =
-                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadM>{});
+                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadN>{});

             return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
         }();
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
// clang-format off
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
+<----------------------< Repeat_N(2)>--------------------->+
| |
+<-- <WarpPerBlock_N(2)> -->+
Warp_N
+--------------+--------------+--------------+--------------+----+----------------+
Warp_M | wrap_0 | wrap_1 | | ^ ^
+--------------+--------------+ | <WarpPerBlock_M(2)> |
| wrap_2 | wrap_3 | | v
+--------------+--------------+--------------+--------------+----+ Block_M
| | |
+ + |
| | | v
+--------------+--------------+--------------+--------------+ +
each Warp-tile (e.g 16 thrd per row)
Vector_N (contiguous pixels each thrd holds along N, or vector size)
+-----------+-----------+-----------+-----------+-----------+
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
+-----------+-----------+-----------+-----------+-----------+
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
+-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template <typename BlockTile_,    // block size, seq<M, N>
          typename WarpPerBlock_, // num warps along seq<M, N>
          typename WarpTile_,     // warp size, seq<M, N>
          typename Vector_,       // contiguous pixels(vector size) along seq<M, N>
          index_t BlockSize_ =
              warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
struct AddRmsnorm2dRdquantShape
{
    // block size
    static constexpr index_t Block_M = BlockTile_::at(number<0>{});
    static constexpr index_t Block_N = BlockTile_::at(number<1>{});

    // num warps along seq<M, N>, within each block
    static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
    static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});

    // warp size
    static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
    static constexpr index_t Warp_N = WarpTile_::at(number<1>{});

    static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
    static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);

    // repeat of each thread along seq<M, N>
    static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
    static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);

    // vector size along seq<M, N>
    static constexpr index_t Vector_M = Vector_::at(number<0>{});
    static constexpr index_t Vector_N = Vector_::at(number<1>{});

    static_assert(Warp_M % Vector_M == 0);
    static_assert(Warp_N % Vector_N == 0);

    // num of threads along seq<M, N>, within each warp
    static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
    static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;

    static constexpr index_t BlockSize = BlockSize_;
};
} // namespace ck_tile
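A hypothetical instantiation to make the derived constants concrete (the tile sizes are illustrative, not a tuned configuration, and the BlockSize check assumes a wave64 build):

    using Shape = ck_tile::AddRmsnorm2dRdquantShape<ck_tile::sequence<2, 2048>, // Block_M, Block_N
                                                    ck_tile::sequence<1, 4>,    // warps per block
                                                    ck_tile::sequence<1, 512>,  // elements per warp
                                                    ck_tile::sequence<1, 8>>;   // vector size

    static_assert(Shape::Repeat_M == 2 && Shape::Repeat_N == 1);
    static_assert(Shape::ThreadPerWarp_N == 64); // 512 elements / 8-wide vectors
    static_assert(Shape::BlockSize == 256);      // warpSize(64) * 4 warps by default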
@@ -26,6 +26,7 @@ struct AddRmsnorm2dRdquantFwdPipelineDefaultPolicy
                 sequence<1, 1, 2, 2>,
                 sequence<0, 3, 0, 3>>{});
     }
+
     template <typename Problem>
     CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution()
     {
@@ -38,9 +38,7 @@
 template <typename BlockTile_,    // block size, seq<M, N>
           typename WarpPerBlock_, // num warps along seq<M, N>
           typename WarpTile_,     // warp size, seq<M, N>
-          typename Vector_,       // contiguous pixels(vector size) along seq<M, N>
-          index_t BlockSize_ =
-              warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
+          typename Vector_>       // contiguous pixels(vector size) along seq<M, N>
 struct Generic2dBlockShape
 {
     // block size
@@ -68,10 +66,12 @@ struct Generic2dBlockShape
     static_assert(Warp_M % Vector_M == 0);
     static_assert(Warp_N % Vector_N == 0);

     // num of threads along seq<M, N>, within each warp
     static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
     static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
+    static constexpr index_t ThreadPerBlock_M = Block_M / Repeat_M / Vector_M;
+    static constexpr index_t ThreadPerBlock_N = Block_N / Repeat_N / Vector_N;

-    static constexpr index_t BlockSize = BlockSize_;
+    static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N;
 };

 } // namespace ck_tile
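A quick numeric check of the new BlockSize definition (the tile sizes are assumed values, not taken from this commit):

    // Suppose BlockTile = seq<2, 2048>, WarpPerBlock = seq<1, 4>,
    //         WarpTile  = seq<1, 512>,  Vector       = seq<1, 8>.
    // Repeat_M         = 2 / (1 * 1)       = 2      Repeat_N = 2048 / (4 * 512) = 1
    // ThreadPerBlock_M = Block_M / Repeat_M / Vector_M = 2 / 2 / 1    = 1
    // ThreadPerBlock_N = Block_N / Repeat_N / Vector_N = 2048 / 1 / 8 = 256
    // BlockSize        = 1 * 256 = 256
    // which matches the old default (warpSize * 1 * 4 = 256) for this shape on a wave64 build.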
@@ -8,17 +8,23 @@
 namespace ck_tile {

-template <bool kPadM_, bool kPadN_, bool UseRawStore_ = true, bool UseMax3_ = false>
+template <bool kPadM_,
+          bool kPadN_,
+          bool UseSmoothInputScale_,
+          bool UseRawStore_ = true,
+          bool UseMax3_ = false>
 struct DynamicQuantEpilogueTraits
 {
     static constexpr bool kPadM = kPadM_;
     static constexpr bool kPadN = kPadN_;
-    static constexpr bool UseRawStore = UseRawStore_;
-    static constexpr bool UseMax3 = UseMax3_;
+    static constexpr bool UseSmoothInputScale = UseSmoothInputScale_;
+    static constexpr bool UseRawStore = UseRawStore_;
+    static constexpr bool UseMax3 = UseMax3_;
 };

 // this epilogue just store out a M*N matrix, row major
 template <typename AccDataType_,
+          typename XScaleDataType_,
           typename YScaleDataType_,
           typename ODataType_,
           typename BlockShape_,
@@ -26,17 +32,20 @@ template <typename AccDataType_,
 struct DynamicQuantEpilogueProblem
 {
     using AccDataType = remove_cvref_t<AccDataType_>;
+    using XScaleDataType = remove_cvref_t<XScaleDataType_>;
     using YScaleDataType = remove_cvref_t<YScaleDataType_>;
     using ODataType = remove_cvref_t<ODataType_>;
     using BlockShape = remove_cvref_t<BlockShape_>; // can consume a generic 2d shape
     using Traits = remove_cvref_t<Traits_>;
 };

+// TODO: we should put descriptor creation function into policy
 template <typename Problem_, typename Policy_ = void>
 struct DynamicQuantEpilogue
 {
     using Problem = remove_cvref_t<Problem_>;
     using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
+    using XScaleDataType = remove_cvref_t<typename Problem::XScaleDataType>;
     using YScaleDataType = remove_cvref_t<typename Problem::YScaleDataType>;
     using ODataType = remove_cvref_t<typename Problem::ODataType>;
     using BlockShape = remove_cvref_t<typename Problem::BlockShape>;
@@ -63,6 +72,33 @@ struct DynamicQuantEpilogue
         return BlockReduce2dCrossWarpSync<P_>{};
     }

+    CK_TILE_DEVICE static constexpr auto MakeSmoothInputScaleTileDistribution()
+    {
+        using S = BlockShape;
+#if 0
+        // don't remove this
+        // Note that if we set the encoding purposely like this, it will result in a compile failure
+        // TODO: x_scale create local-scratch to accept arbitrary acc input (with same length)
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M>,
+                tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
+                tuple<sequence<0, 1>, sequence<0, 1>>,
+                tuple<sequence<1, 1>, sequence<2, 2>>,
+                sequence<0, 1, 1>,
+                sequence<0, 0, 3>>{});
+#else
+        return make_static_tile_distribution(
+            tile_distribution_encoding<
+                sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
+                tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
+                tuple<sequence<0, 1>, sequence<0, 1>>,
+                tuple<sequence<0, 1>, sequence<1, 2>>,
+                sequence<1, 1>,
+                sequence<0, 3>>{});
+#endif
+    }
+
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();
@@ -71,8 +107,12 @@ struct DynamicQuantEpilogue

     // TODO: this function assume store out vector size is the same as OAccTile last dimension size
     // how do we fix this ?
-    template <typename ODramWindowTmp, typename YScaleWindow, typename OAccTile>
+    template <typename ODramWindowTmp,
+              typename XScaleWindow,
+              typename YScaleWindow,
+              typename OAccTile>
     CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
+                                   const XScaleWindow& x_scale_window_,
                                    YScaleWindow& y_scale_window,
                                    const OAccTile& o_acc_tile,
                                    void* smem)
@@ -80,6 +120,18 @@ struct DynamicQuantEpilogue
         auto reduce = GetBlockReduce2d();
         auto reduce_sync = GetBlockReduce2dSync();
         auto reduce_crosswarp_sync = GetBlockReduce2dCrossWarpSync();

+        const auto x_scale_window =
+            make_tile_window(x_scale_window_, MakeSmoothInputScaleTileDistribution());
+        auto x_scale = load_tile(x_scale_window);
+
+        auto o_acc_tmp = o_acc_tile;
+
+        sweep_tile(o_acc_tmp, [&](auto idx) {
+            constexpr auto j_idx = make_tuple(idx[number<1>{}]);
+            const auto xs_ = type_convert<AccDataType>(x_scale[j_idx]);
+            o_acc_tmp(idx) = o_acc_tmp(idx) * xs_;
+        });
+
         const auto f_absmax = [](auto acc_, auto v_0_) { return max(acc_, abs(v_0_)); };
@@ -87,10 +139,9 @@ struct DynamicQuantEpilogue
            constexpr auto y_size_per_row =
                OAccTile{}.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(
                    number<1>{});
-            // constexpr auto y_size_per_row = OAccTile::get_lengths()[number<1>{}];
            if constexpr(UseMax3 && std::is_same_v<AccDataType, float> && y_size_per_row % 2 == 0)
            {
-                // fast max3 implementation
+                // fast max3+abs implementation
                const auto f_max3 = [](auto acc_, auto v_0_, auto v_1_) {
                    float rtn;
                    asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
@@ -98,11 +149,11 @@ struct DynamicQuantEpilogue
                                 : "v"(acc_), "v"(v_0_), "v"(v_1_));
                    return rtn;
                };
-                return reduce(o_acc_tile, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
+                return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_max3, sequence<1, 2>{});
            }
            else
            {
-                return reduce(o_acc_tile, type_convert<AccDataType>(0), f_absmax);
+                return reduce(o_acc_tmp, type_convert<AccDataType>(0), f_absmax);
            }
        }();

        reduce_sync(row_absmax, f_absmax);
@@ -117,23 +168,20 @@ struct DynamicQuantEpilogue
        store_tile(y_scale_window, cast_tile<YScaleDataType>(y_scale));

-        auto o_acc_scaled_tile =
-            make_static_distributed_tensor<AccDataType>(o_acc_tile.get_tile_distribution());
-
-        sweep_tile(o_acc_tile, [&](auto idx) {
+        sweep_tile(o_acc_tmp, [&](auto idx) {
            constexpr auto row_id = make_tuple(idx[number<0>{}]);
-            o_acc_scaled_tile(idx) = o_acc_tile[idx] / y_scale(row_id);
+            o_acc_tmp(idx) = o_acc_tmp[idx] / y_scale(row_id);
        });

        // TODO: this is ugly
        if constexpr(UseRawStore && (kPadM || kPadN))
        {
-            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile));
+            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
            buffer_store_fence();
        }
        else
        {
-            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_scaled_tile));
+            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tmp));
        }
    }
};
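In scalar form, the epilogue with the new smooth-quant input scale does roughly the following per row. This is a sketch with hypothetical names and plain loops; the real code operates on distributed tiles, and the amax/127 convention is taken from the AddRmsnorm2dRdquant host-args comment, with the amax == 0 corner case ignored here.

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    void dynamic_quant_rowwise(const float* x_scale, // [N] smooth-quant per-column input scale
                               float* acc,           // [M*N] fp32 accumulator, overwritten in place
                               float* y_scale,       // [M]  per-row dynamic scale output
                               int8_t* out,          // [M*N] int8 output
                               int M, int N)
    {
        for(int m = 0; m < M; ++m)
        {
            float amax = 0.f;
            for(int n = 0; n < N; ++n)
            {
                acc[m * N + n] *= x_scale[n];                    // new: smooth-quant scaling
                amax = std::max(amax, std::abs(acc[m * N + n])); // rowwise absmax
            }
            y_scale[m] = amax / 127.f; // rowwise dynamic scale
            for(int n = 0; n < N; ++n)
                out[m * N + n] = static_cast<int8_t>(
                    std::clamp(acc[m * N + n] / y_scale[m], -128.f, 127.f)); // saturate to int8
        }
    }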
@@ -230,7 +230,15 @@ struct PageBlockNavigator
     CK_TILE_HOST_DEVICE
     DataType* get_block_ptr(index_t block_index) const
     {
-        return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
+        if(block_index < num_blocks)
+        {
+            return physical_blocks + physical_block_indices[block_index] * block_stride +
+                   fixed_offset;
+        }
+        else
+        {
+            return nullptr;
+        }
     }

     CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
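A caller-side sketch of what the new bounds check implies (hypothetical caller code; the actual call sites are not shown in this diff):

    if(auto* block = navigator.get_block_ptr(block_index))
    {
        // block_index < num_blocks: safe to address into the physical block
    }
    else
    {
        // out-of-range request now yields nullptr instead of indexing past
        // physical_block_indices; the caller should skip or clamp the index
    }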