"include/device.hpp" did not exist on "7c098ddc0e48d09e767479383cd2425b03aa0f8a"
Unverified commit d6709dc3, authored by Anthony Chang, committed by GitHub

Fix gemm-softmax-gemm-permute padding cases (#409)

* fix example; make padding on by default in example; fix argument checks

* fix Gemm1KPack, which has regressed since PR #399
parent ce74cea4
@@ -58,7 +58,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
 using B1ElementOp = PassThrough;
 using CElementOp = PassThrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNOPadding;
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
@@ -77,7 +77,7 @@ using DeviceGemmInstance =
         Acc0ElementOp,
         B1ElementOp,
         CElementOp,
-        GemmDefault,
+        GemmSpec,
         1,
         256,
         128, // MPerBlock
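Why the switch to MNOPadding matters here: with GemmSpecialization::Default the kernel assumes every GEMM dimension is already a multiple of its per-block tile, so arbitrary problem sizes get rejected instead of padded. A minimal sketch of the round-up arithmetic a padding specialization implies; only MPerBlock = 128 is visible in this diff, the other tile values and the RoundUp helper are illustrative assumptions:

```cpp
#include <cstdio>

// Illustrative round-up implied by a padding specialization: each of M, N, O
// is padded to the nearest multiple of its per-block tile before tiling the grid.
constexpr int RoundUp(int x, int tile) { return (x + tile - 1) / tile * tile; }

int main()
{
    constexpr int MPerBlock = 128;                  // from this example's config
    constexpr int NPerBlock = 128, OPerBlock = 128; // assumed for illustration
    const int M = 1000, N = 1000, O = 50;           // arbitrary, non-tile-multiple sizes

    std::printf("padded M=%d N=%d O=%d\n",
                RoundUp(M, MPerBlock),  // 1024
                RoundUp(N, NPerBlock),  // 1024
                RoundUp(O, OPerBlock)); // 128
    return 0;
}
```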
@@ -166,8 +166,6 @@ int main(int argc, char* argv[])
     // C_g0_m_g1_o = permute(C_g0_g1_m_o, [0, 2, 1, 3])
     ck::index_t G0 = 7;
     ck::index_t G1 = 13;
-    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
-    std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
     if(argc == 1)
     {
@@ -204,6 +202,9 @@ int main(int argc, char* argv[])
         exit(0);
     }
+    std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+    std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
+
     const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
     const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
     const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
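Moving the two vectors below the argument parsing is the example fix itself: M and O can be overridden on the command line, so building c_gs_ms_os_lengths/strides before parsing baked in stale values. The strides encode the permuted physical layout [G0, M, G1, O] while the logical dimension order stays (G0, G1, M, O); a self-contained check of that correspondence (plain C++, hypothetical names, not CK API):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Sketch verifying that the example's strides {M*G1*O, O, G1*O, 1} over logical
// dims (G0, G1, M, O) describe the permuted physical layout C[g0][m][g1][o].
// The offset helper is illustrative, not CK API.
int main()
{
    const int64_t G0 = 7, G1 = 13, M = 256, O = 64; // M, O normally parsed from argv

    const std::vector<int64_t> strides{M * G1 * O, O, G1 * O, 1};

    // linear offset of logical element (g0, g1, m, o)
    const auto offset = [&](int64_t g0, int64_t g1, int64_t m, int64_t o) {
        return g0 * strides[0] + g1 * strides[1] + m * strides[2] + o * strides[3];
    };

    assert(offset(0, 0, 0, 1) == 1);          // o is innermost (contiguous)
    assert(offset(0, 1, 0, 0) == O);          // g1 hops by O
    assert(offset(0, 0, 1, 0) == G1 * O);     // m hops by a whole [G1, O] slab
    assert(offset(1, 0, 0, 0) == M * G1 * O); // g0 hops by a whole [M, G1, O] slab
    return 0;
}
```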
...
@@ -693,9 +693,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
         }
         // Check if C permute dimension matches GEMM + GEMM shape
-        const index_t c_g      = arg.c_grid_desc_g_m_n_.GetLength(I0);
-        const index_t c_m      = arg.c_grid_desc_g_m_n_.GetLength(I1);
-        const index_t c_gemm1n = arg.c_grid_desc_g_m_n_.GetLength(I2);
+        const index_t c_g      = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
+        const index_t c_m      = arg.c_grid_desc_m_n_.GetLength(I0);
+        const index_t c_gemm1n = arg.c_grid_desc_m_n_.GetLength(I1);
         const index_t a_m       = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
         const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1);
         if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n))
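What this argument-check fix addresses, as far as the diff shows: under MNOPadding, a_grid_desc_ak0_m_ak1_ carries padded lengths while c_grid_desc_g_m_n_ keeps the user's unpadded ones, so comparing c_m against a_m could never succeed for a padded problem; reading c_m and c_gemm1n from c_grid_desc_m_n_ compares padded against padded. A toy illustration of the presumed descriptor semantics (not CK code):

```cpp
#include <cstdio>

// Illustration of why the old check failed once padding is on: the G/M/N output
// descriptor keeps the user's unpadded sizes, while the descriptors the kernel
// tiles with are rounded up to block multiples. Numbers are illustrative.
int main()
{
    const int MPerBlock = 128;
    const int M_user    = 1000;
    const int M_padded  = (M_user + MPerBlock - 1) / MPerBlock * MPerBlock; // 1024

    const int c_m_old = M_user;   // old: c_grid_desc_g_m_n_.GetLength(I1), unpadded
    const int c_m_new = M_padded; // new: c_grid_desc_m_n_.GetLength(I0), padded
    const int a_m     = M_padded; // a_grid_desc_ak0_m_ak1_.GetLength(I1), padded

    std::printf("old check (c_m == a_m): %s\n", c_m_old == a_m ? "pass" : "fail"); // fail
    std::printf("new check (c_m == a_m): %s\n", c_m_new == a_m ? "pass" : "fail"); // pass
    return 0;
}
```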
...
@@ -594,10 +594,17 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
-        // selected_mfma.k_per_blk <= B1K1 <= selected_mfma.group_size
-        constexpr index_t Gemm1KPack = math::max(
-            math::gcd(MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size, B1K1),
-            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
+        // selected_mfma.k_per_blk <= Gemm1KPack
+        //
+        // Following the rationale behind Gemm0KPack, Gemm1KPack would be the lowest common
+        // multiple of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this
+        // case Gemm1KPack can't be higher than A1K1 itself, because the A1 matrix is
+        // distributed in VGPRs with 'group_size' contiguous elements; a Gemm1KPack greater
+        // than A1K1 would mismatch the summation index, e.g. c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
+        // Therefore we may just as well assign Gemm1KPack = group_size.
+        constexpr index_t Gemm1KPack =
+            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
         auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
             BlockSize,
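The new comment's summation-index example, made concrete: A1 fragments sit in VGPRs as group_size contiguous k-elements per group, so a pack wider than group_size would gather non-adjacent k indices from a1 while b1 supplies a contiguous k range. A toy sketch with illustrative numbers (group_size and the inter-group stride are assumptions, not values read from MfmaSelector):

```cpp
#include <cstdio>

// Toy model (not kernel code) of the comment's example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]:
// A1 lives in VGPRs as 'group_size' contiguous k-elements per group, and consecutive
// groups start 'group_stride' apart in k (both values are illustrative assumptions).
int main()
{
    constexpr int group_size   = 4;              // assumed selected_mfma.group_size
    constexpr int group_stride = 8;              // assumed k-distance between group starts
    constexpr int bad_pack     = 2 * group_size; // a hypothetical Gemm1KPack > A1K1

    std::printf("a1 k-indices in one %d-wide pack: ", bad_pack);
    for(int i = 0; i < bad_pack; ++i)
    {
        const int g = i / group_size; // which VGPR group
        const int e = i % group_size; // element within the group
        std::printf("%d ", g * group_stride + e); // prints: 0 1 2 3 8 9 10 11
    }
    std::printf("\nb1 k-indices: 0..%d (contiguous) -> summation indices mismatch\n",
                bad_pack - 1);
    return 0;
}
```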
...
@@ -645,10 +645,17 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
-        // selected_mfma.k_per_blk <= B1K1 <= selected_mfma.group_size
-        constexpr index_t Gemm1KPack = math::max(
-            math::gcd(MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size, B1K1),
-            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
+        // selected_mfma.k_per_blk <= Gemm1KPack
+        //
+        // Following the rationale behind Gemm0KPack, Gemm1KPack would be the lowest common
+        // multiple of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this
+        // case Gemm1KPack can't be higher than A1K1 itself, because the A1 matrix is
+        // distributed in VGPRs with 'group_size' contiguous elements; a Gemm1KPack greater
+        // than A1K1 would mismatch the summation index, e.g. c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
+        // Therefore we may just as well assign Gemm1KPack = group_size.
+        constexpr index_t Gemm1KPack =
+            MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
         auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
             BlockSize,
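For contrast with the removed formula: the old max(gcd(group_size, B1K1), k_per_blk) could land below group_size for small B1K1, which is plausibly the Gemm1KPack regression from PR #399 that the commit message mentions, though the diff alone does not confirm which configuration broke. A constexpr comparison under assumed parameter values:

```cpp
#include <algorithm>
#include <cstdio>
#include <numeric>

// Sketch contrasting the removed Gemm1KPack formula with the fixed value.
// All three parameters are illustrative, not taken from a specific
// MfmaSelector instantiation.
int main()
{
    constexpr int group_size = 4; // assumed selected_mfma.group_size
    constexpr int k_per_blk  = 1; // assumed selected_mfma.k_per_blk
    constexpr int B1K1       = 2; // assumed B1 k-pack from the tuning config

    // old: max(gcd(group_size, B1K1), k_per_blk) -> max(gcd(4, 2), 1) = 2
    constexpr int old_pack = std::max(std::gcd(group_size, B1K1), k_per_blk);
    // new: always group_size -> 4, matching how A1 is laid out in VGPRs
    constexpr int new_pack = group_size;

    std::printf("old Gemm1KPack = %d, new Gemm1KPack = %d\n", old_pack, new_pack);
    return 0;
}
```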
...