Commit 9a4fd1bc authored by Rosty Geyyer's avatar Rosty Geyyer
Browse files

Format

parent 0c028828
......@@ -96,9 +96,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// ABDataTypeAdjusted -> ABDataType throughout this file
using ABDataTypeAdjusted = conditional_t<is_same_v<ABDataType, ck::half_t>,
ck::bhalf_t,
ABDataType>;
using ABDataTypeAdjusted =
conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
......@@ -488,7 +487,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataTypeAdjusted*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
static_cast<ABDataTypeAdjusted*>(p_shared),
a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned,
......
......@@ -265,9 +265,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>,
ck::bhalf_t,
FloatAB>;
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
// M0/M1/M1Padding
static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
......@@ -754,8 +752,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatABAdjusted*>(p_shared),
a_k0_m_k1_block_desc.GetElementSpaceSize());
static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size,
......
......@@ -135,9 +135,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// we convert fp16->fp32->bf16 and execute bf16 mfma instruction
// when mfma if fixed, remove this section and update
// FloatABAdjusted -> FloatAB throughout this file
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>,
ck::bhalf_t,
FloatAB>;
using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment