Commit 66e61076 authored by aska-0096

Sanity pass.

parent 0c51a35e
@@ -9,7 +9,7 @@ using ADataType = ck::half_t;
using BDataType = int8_t;
using ScaleDataType = ck::half_t;
using AccDataType = float;
-using CShuffleDataType = float;
+using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
using ALayout = Row;
......
@@ -28,7 +28,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
// assume scale tensor is [1, n]
-Tensor<ScaleDataType> scale_k_n(f_host_tensor_descriptor(K, N, 0, BLayout{}));
+Tensor<ScaleDataType> scale_k_n(f_host_tensor_descriptor(K, N, 0, Row{}));
switch(config.init_method)
{
@@ -51,7 +51,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
case 4:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
-ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{1.f, 1.f}(scale_k_n);
+ck::utils::FillUniformDistributionIntegerValue<ScaleDataType>{2.f, 2.f}(scale_k_n);
break;
case 5:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-2.f, 2.f}(a_m_k);
@@ -64,6 +64,50 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
ck::utils::FillUniformDistribution<ScaleDataType>{-1.f, 1.f}(scale_k_n);
}
#if 0
printf("Matrix A:\n");
for (int im = 0; im < M; im++)
{
for (int ik = 0; ik < K; ik++)
{
if(ik % 16 == 0){
printf("|");
}
printf(" %04x", *(reinterpret_cast<uint16_t*>(&a_m_k(im,ik))));
}
printf("\n");
}
printf("Matrix B:\n");
for (int in = 0; in < N; in++)
{
for (int ik = 0; ik < K; ik++)
{
if(ik % 16 == 0){
printf("|");
}
printf(" %02x", b_k_n(ik,in));
}
printf("\n");
}
printf("Matrix Scale:\n");
for (int in = 0; in < N; in++)
{
for (int ik = 0; ik < K; ik++)
{
if(ik % 16 == 0){
printf("|");
}
printf(" %04x", *(reinterpret_cast<uint16_t*>(&scale_k_n(ik,in))));
}
printf("\n");
}
#endif
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
......
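For orientation, the dequantization this example exercises can be written as a plain host loop. The sketch below is not CK's reference kernel; it uses float as a stand-in for ck::half_t and assumes row-major A, B, and C:

```cpp
// Minimal host reference for the fp16 x int8 dequant-GEMM checked above (a sketch).
// B is int8 and scale is logically [1, N], broadcast along K, so
//   C[m][n] = sum_k A[m][k] * (scale[n] * (float)B[k][n]).
#include <cstdint>
#include <vector>

void reference_dequant_gemm(const std::vector<float>& a,      // M x K, row-major
                            const std::vector<int8_t>& b,     // K x N, row-major
                            const std::vector<float>& scale,  // N entries, one per output column
                            std::vector<float>& c,            // M x N, row-major
                            int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += a[m * K + k] * (scale[n] * static_cast<float>(b[k * N + n]));
            c[m * N + n] = acc;
        }
}
```

With init_method 4 above (A and B filled with 1, the scale with 2), every output element of this reference comes out as 2*K, which makes the sanity pass easy to eyeball.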
@@ -309,7 +309,8 @@ struct Blockwise_fpAintB_GemmWMMA
b_thread_desc_.GetElementSpaceSize());
auto scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ScaleDataType>(
scale_thread_desc_.GetElementSpaceSize());
-auto converted_b_thread_buf = b_thread_buf;
+auto converted_b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
+b_thread_desc_.GetElementSpaceSize());
// basic intrinsic to determine loopover direction
if constexpr(MRepeat < NRepeat)
@@ -345,7 +346,7 @@ struct Blockwise_fpAintB_GemmWMMA
scale_thread_buf);
// convert B from int8 to fp16, multiply scale
-static_for<0, b_thread_buf.size(), 1>{}([&](auto i) {
+static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) {
converted_b_thread_buf(i) =
scale_thread_buf[i / WmmaK] *
type_convert<ADataType>(b_thread_buf[i]);
@@ -390,6 +391,20 @@ struct Blockwise_fpAintB_GemmWMMA
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read weight scale
scale_thread_copy_.Run(
scale_block_desc_1_n0_n1_n2_1,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_block_buf,
scale_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
scale_thread_buf);
#if 0
printf("Tid: %03d, n: %02d, scale_thread_buf: %04x\n",
get_thread_local_1d_id(), n0.value,
*(reinterpret_cast<const uint16_t*>(&scale_thread_buf[n0]))
);
#endif
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
// k=0,kpack*1, ..
@@ -401,15 +416,6 @@ struct Blockwise_fpAintB_GemmWMMA
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
-// read weight scale
-scale_thread_copy_.Run(
-scale_block_desc_1_n0_n1_n2_1,
-make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
-scale_block_buf,
-scale_thread_desc_,
-make_tuple(I0, n0, I0, I0, I0, I0),
-scale_thread_buf);
// convert B from int8 to fp16, multiply scale
static_for<0, b_thread_buf.Size(), 1>{}([&](auto i) {
converted_b_thread_buf(i) = scale_thread_buf[i / WmmaK] *
@@ -423,7 +429,71 @@ struct Blockwise_fpAintB_GemmWMMA
a_thread_desc_,
make_tuple(I0, m0, I0, I0, I0, I0),
a_thread_buf);
if (true){
#if 0
printf("Tid: %03d, m, n, k: %02d, %02d, %02d, a_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n",
get_thread_local_1d_id(), m0.value, n0.value, k.value,
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<0>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<1>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<2>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<3>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<4>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<5>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<6>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<7>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<8>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<9>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<10>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<11>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<12>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<13>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<14>{}])),
*(reinterpret_cast<const uint16_t*>(&a_thread_buf[Number<15>{}]))
);
#endif
#if 0
printf("Tid: %03d, m, n, k: %02d, %02d, %02d, b_thread_buf: %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x| %02x %02x %02x %02x|\n",
get_thread_local_1d_id(), m0.value, n0.value, k.value,
b_thread_buf[Number<0>{}],
b_thread_buf[Number<1>{}],
b_thread_buf[Number<2>{}],
b_thread_buf[Number<3>{}],
b_thread_buf[Number<4>{}],
b_thread_buf[Number<5>{}],
b_thread_buf[Number<6>{}],
b_thread_buf[Number<7>{}],
b_thread_buf[Number<8>{}],
b_thread_buf[Number<9>{}],
b_thread_buf[Number<10>{}],
b_thread_buf[Number<11>{}],
b_thread_buf[Number<12>{}],
b_thread_buf[Number<13>{}],
b_thread_buf[Number<14>{}],
b_thread_buf[Number<15>{}]
);
#endif
#if 0
printf("Tid: %03d, m, n, k: %02d, %02d, %02d, converted_b_thread_buf: %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x| %04x %04x %04x %04x|\n",
get_thread_local_1d_id(), m0.value, n0.value, k.value,
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<0>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<1>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<2>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<3>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<4>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<5>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<6>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<7>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<8>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<9>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<10>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<11>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<12>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<13>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<14>{}])),
*(reinterpret_cast<const uint16_t*>(&converted_b_thread_buf[Number<15>{}]))
);
#endif
}
vector_type<ADataType, WmmaK> a_thread_vec;
vector_type<ADataType, WmmaK> b_thread_vec;
@@ -497,7 +567,7 @@ struct Blockwise_fpAintB_GemmWMMA
I1,
Number<B_KRow>{},
I1,
-Number<B_K1>{}),
+I1),
make_tuple(I0, I1, I0, I0, I0, I0));
// C[M, N, NumRegWMMA]
@@ -587,11 +657,11 @@ struct Blockwise_fpAintB_GemmWMMA
ScaleDataType,
decltype(scale_block_desc_1_n0_n1_n2_1),
decltype(scale_thread_desc_),
-Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, B_K1>,
+Sequence<WmmaK / B_K1 / B_KRow, 1, 1, B_KRow, 1, 1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
-B_K1,
-B_K1>;
+1,
+1>;
};
template <>
......
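The per-thread dequant step above boils down to one scale register being reused for each WmmaK-wide slice of B registers, which is also why this commit hoists scale_thread_copy_.Run out of the k loop: the scale depends only on the n0 index. A simplified sketch with plain arrays instead of CK static buffers, again using float as a stand-in for ck::half_t:

```cpp
// Sketch of the register-level dequant: every WmmaK consecutive B registers
// share one scale value, mirroring
//   converted_b_thread_buf(i) = scale_thread_buf[i / WmmaK] * type_convert<ADataType>(b_thread_buf[i]);
#include <cstdint>

constexpr int WmmaK = 16; // assumed WMMA K tile size

void dequant_registers(const int8_t* b_regs, const float* scale_regs,
                       float* converted_b_regs, int num_regs)
{
    for(int i = 0; i < num_regs; ++i)
        converted_b_regs[i] = scale_regs[i / WmmaK] * static_cast<float>(b_regs[i]);
}
```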
@@ -182,8 +182,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
-// When K = 1, it might be scale tensor.
-assert(K % K1 == 0 && K != 1);
+assert(K % K1 == 0);
if constexpr(BEnableLds)
{
@@ -216,6 +215,52 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
}
}
static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB = 0)
{
// assume Scale is [1, N]
const auto scale_grid_desc_n_k = [&]() {
const auto scale_grid_desc_nraw_kraw =
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
return matrix_padder.PadBDescriptor_N_K(scale_grid_desc_nraw_kraw);
}();
const auto N = scale_grid_desc_n_k.GetLength(I0);
const auto K = scale_grid_desc_n_k.GetLength(I1);
// When K = 1, it might be scale tensor.
assert(K % K1 == 0 && K != 1);
if constexpr(BEnableLds)
{
const index_t K0 = K / K1;
return transform_tensor_descriptor(
scale_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(K0, 1)), // Reduce K1 = 1
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
else
{
constexpr auto B_KRow = 2;
constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
const auto B_KWmma = K / WmmaK;
const auto N0 = N / NPerBlock;
// 0 1 0 1 2 3 4 5 6
// N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
return transform_tensor_descriptor(
scale_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(
B_KWmma, Number<B_K0PerWmma>{}, Number<B_KRow>{}, K1Number)),
make_unmerge_transform(
make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
}
}
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
@@ -237,7 +282,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
// Gridwise descriptor, mapping to the whole given problem.
using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
-using ScaleGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
+using ScaleGridDesc = decltype(MakeScaleGridDescriptor(1, 1, 0));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
// GridwiseGemm
@@ -330,7 +375,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
{
a_grid_desc_ = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
b_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, StrideB);
-scale_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, 0);
+scale_grid_desc_ = DeviceOp::MakeScaleGridDescriptor(K, N, 0);
c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC);
block_2_ctile_map_ =
......
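MakeScaleGridDescriptor describes the logically [1, N] scale as an (N, K) tensor whose K stride is 0, so every k index resolves to the same element. The sketch below illustrates just that addressing trick with a hypothetical scale_offset helper, not CK's descriptor machinery:

```cpp
// With stride_k == 0, element (n, k) maps to offset n*1 + k*0 == n,
// so reads at any k hit the same scale[n].
#include <cstddef>
#include <cstdio>

inline std::size_t scale_offset(std::size_t n, std::size_t k,
                                std::size_t stride_n = 1, std::size_t stride_k = 0)
{
    return n * stride_n + k * stride_k;
}

int main()
{
    for(std::size_t k = 0; k < 4; ++k)
        std::printf("offset of (n=5, k=%zu) = %zu\n", k, scale_offset(5, k));
    return 0;
}
```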
@@ -52,6 +52,12 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
__shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];
if (false && get_thread_local_1d_id()==0){
printf("lds_size: %lu\n", GridwiseGemm::SharedMemTrait::lds_size);
printf("lds_a_size: %d\n", GridwiseGemm::SharedMemTrait::a_block_space_size_aligned);
printf("lds_b_size: %d\n", GridwiseGemm::SharedMemTrait::b_block_space_size_aligned);
printf("lds_scale_size: %d\n", GridwiseGemm::SharedMemTrait::scale_block_space_size_aligned);
}
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
@@ -262,7 +268,7 @@ struct GridwiseFpAintBGemm_Wmma
constexpr auto K0PerBlock = KPerBlock / K1;
return make_naive_tensor_descriptor(
-make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, I1),
make_tuple(I0, I1, I0));
}
else
@@ -276,7 +282,7 @@ struct GridwiseFpAintBGemm_Wmma
Number<K0PerWmma>{},
I1,
I1,
-K1),
+I1),
make_tuple(I0, I1, I0, I0, I0, I0, I0));
}
}();
@@ -424,6 +430,52 @@ struct GridwiseFpAintBGemm_Wmma
return b_wave_desc;
}
template <typename ScaleBlockDesc_>
__host__ __device__ static constexpr auto MakeScaleWaveDescriptor(const ScaleBlockDesc_&)
{
constexpr auto scale_wave_desc = [&]() {
if constexpr(BEnableLds)
{
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr auto B_K0 = ScaleBlockDesc_{}.GetLength(I0);
constexpr auto B_K1 = ScaleBlockDesc_{}.GetLength(I2);
constexpr auto B_KRow = I1;
return transform_tensor_descriptor(
ScaleBlockDesc_{},
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
make_unmerge_transform(make_tuple(
Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
make_pass_through_transform(Number<B_K1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
}
else
{
// KWmma_NRepeat_NWave_K0PerWmma_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
constexpr auto KWmma = ScaleBlockDesc_{}.GetLength(I0);
constexpr auto K0PerWmma = ScaleBlockDesc_{}.GetLength(I3);
constexpr auto B_KRow = ScaleBlockDesc_{}.GetLength(I4);
constexpr auto B_K1 = ScaleBlockDesc_{}.GetLength(I6);
// Workaround, Freeze transform
return make_naive_tensor_descriptor(make_tuple(Number<KWmma * K0PerWmma>{},
Number<NRepeat>{},
I1,
Number<B_KRow>{},
I1,
Number<B_K1>{}),
make_tuple(I0,
I1,
I0,
I0,
I0,
I0));
}
}();
return scale_wave_desc;
}
__host__ __device__ static constexpr auto
// *Caution Here repeat is shuffle repeat
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
@@ -590,9 +642,10 @@ struct GridwiseFpAintBGemm_Wmma
: 0;
static constexpr auto a_block_space_offset = 0;
-static constexpr auto b_block_space_offset = a_block_space_size_aligned;
+static constexpr auto b_block_space_offset =
+(a_block_space_offset + a_block_space_size_aligned) * sizeof(ADataType)/sizeof(BDataType);
static constexpr auto scale_block_space_offset =
-b_block_space_offset + b_block_space_size_aligned;
+(b_block_space_offset + b_block_space_size_aligned) * sizeof(BDataType)/sizeof(ScaleDataType);
// LDS allocation for C shuffle in LDS
static constexpr auto c_shuffle_block_space_size =
@@ -753,7 +806,7 @@ struct GridwiseFpAintBGemm_Wmma
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<BDataType*>(p_shared) + SharedMemTrait::b_block_space_offset,
SharedMemTrait::b_block_space_size_aligned);
// printf("b_lds_offset: %lu\n", SharedMemTrait::b_block_space_offset);
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
@@ -834,13 +887,15 @@ struct GridwiseFpAintBGemm_Wmma
auto scale_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ScaleDataType*>(p_shared) + SharedMemTrait::scale_block_space_offset,
SharedMemTrait::scale_block_space_size_aligned);
// printf("scale_lds_offset: %lu\n", SharedMemTrait::scale_block_space_offset);
auto scale_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
-Sequence<K0PerBlock, NPerBlock, K1>,
+// Reduce slice length K1 to 1
+Sequence<K0PerBlock, NPerBlock, I1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
ScaleDataType,
@@ -851,10 +906,10 @@ struct GridwiseFpAintBGemm_Wmma
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
-BBlockTransferSrcScalarPerVector,
-BBlockTransferDstScalarPerVector_K1,
1,
1,
+1, // no effect
+1, // no effect
BThreadTransferSrcResetCoordinateAfterRun,
true,
NumGemmKPrefetchStage>(
@@ -926,7 +981,7 @@ struct GridwiseFpAintBGemm_Wmma
AccDataType,
decltype(MakeAWaveDescriptor(a_block_desc)),
decltype(MakeBWaveDescriptor(b_block_desc)),
-decltype(MakeBWaveDescriptor(scale_block_desc)),
+decltype(MakeScaleWaveDescriptor(scale_block_desc)),
MPerBlock,
NPerBlock,
KPerBlock,
......
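The reworked SharedMemTrait offsets account for the fact that p_shared is a single raw char buffer while each offset is added to a pointer of that tile's element type, so a size measured in ADataType elements must be rescaled into BDataType (and then ScaleDataType) elements. A small worked example with made-up tile sizes:

```cpp
// Sketch of the offset arithmetic: a byte offset of
//   a_block_space_size_aligned * sizeof(ADataType)
// becomes (that many bytes) / sizeof(BDataType) elements when used as
// b_block_space_offset on a BDataType* pointer. Sizes below are hypothetical.
#include <cstddef>
#include <cstdio>

int main()
{
    constexpr std::size_t a_elems  = 4096; // assumed aligned A tile size, in ADataType elements
    constexpr std::size_t sizeof_a = 2;    // e.g. fp16
    constexpr std::size_t sizeof_b = 1;    // e.g. int8
    constexpr std::size_t b_offset_in_b_elems = a_elems * sizeof_a / sizeof_b;
    std::printf("B tile starts %zu BDataType elements into LDS\n", b_offset_in_b_elems);
    return 0;
}
```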
@@ -581,9 +581,9 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleBlockDesc,
+typename ScaleBlockTransfer,
typename ScaleGridBuffer,
typename ScaleBlockBuffer,
-typename ScaleBlockTransfer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
@@ -658,6 +658,116 @@ struct GridwiseGemmPipeline_v1_dequant<1, true, true>
}
};
template <>
struct GridwiseGemmPipeline_v1_dequant<1, true, false>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
__host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
__host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
{
return num_loop > 1;
}
template <bool HasMainLoop,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename ScaleGridDesc,
typename ScaleBlockDesc,
typename ScaleBlockTransfer,
typename ScaleGridBuffer,
typename ScaleBlockBuffer,
typename BlockwiseGemm,
typename CThreadBuffer>
__device__ static void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
const ScaleGridDesc& scale_grid_desc,
const ScaleBlockDesc& scale_block_desc,
ScaleBlockTransfer& scale_blockwise_copy,
const ScaleGridBuffer& scale_grid_buf,
ScaleBlockBuffer& scale_block_buf,
const BlockwiseGemm& blockwise_gemm,
CThreadBuffer& c_thread_buf,
index_t num_loop)
{
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
auto b_block_buf_switch = b_block_buf;
// preload data into LDS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf);
scale_blockwise_copy.Run(scale_grid_desc, scale_grid_buf, scale_block_desc, b_block_origin_idx, scale_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Initialize C
c_thread_buf.Clear();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
b_blockwise_copy.Run(
b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch);
block_sync_lds();
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_block_buf = b_block_buf_switch;
++i;
} while(i < (num_loop - 1));
}
// tail
{
block_sync_lds();
blockwise_gemm.Run(a_block_buf, b_block_buf, scale_block_buf, c_thread_buf);
block_sync_lds();
}
}
};
template <index_t NumPrefetch>
struct GridwiseGemmPipelineInterwave_v1;
......
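The new GridwiseGemmPipeline_v1_dequant<1, true, false> specialization appears to keep A in LDS with one prefetch stage while B (and its scale) stay in registers: the next B tile is read into a second register buffer while the current one is consumed, then the buffers are swapped. The loop skeleton below is a sketch of that ping-pong structure under those assumptions, not the CK pipeline itself:

```cpp
// Miniature of the main loop: prefetch the next B tile, consume the current one,
// swap, and finish with a tail iteration that has nothing left to prefetch.
template <typename Tile, typename LoadB, typename Gemm>
void pingpong_mainloop(LoadB load_next_b, Gemm gemm, Tile b_cur, int num_loop)
{
    Tile b_next{};
    for(int i = 0; i < num_loop - 1; ++i)
    {
        b_next = load_next_b(i + 1); // prefetch the next B tile into registers
        gemm(b_cur);                 // consume the current tile
        b_cur = b_next;              // swap buffers for the next iteration
    }
    gemm(b_cur);                     // tail: last tile
}
```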