Commit 9a4fd1bc authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Format

parent 0c028828
...@@ -96,9 +96,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -96,9 +96,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update // when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file // ABDataTypeAdjusted -> ABDataType throughout this file
using ABDataTypeAdjusted = conditional_t<is_same_v<ABDataType, ck::half_t>, using ABDataTypeAdjusted =
ck::bhalf_t, conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
ABDataType>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{ {
...@@ -488,7 +487,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -488,7 +487,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataTypeAdjusted*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<ABDataTypeAdjusted*>(p_shared),
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned, static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned,
......
...@@ -265,9 +265,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -265,9 +265,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update // when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file // FloatABAdjusted -> FloatAB throughout this file
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
ck::bhalf_t,
FloatAB>;
// M0/M1/M1Padding // M0/M1/M1Padding
static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{}; static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
...@@ -754,8 +752,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight ...@@ -754,8 +752,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatABAdjusted*>(p_shared), static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size, static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size,
......
...@@ -135,9 +135,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -135,9 +135,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update // when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file // FloatABAdjusted -> FloatAB throughout this file
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
ck::bhalf_t,
FloatAB>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment