Commit dde01029 authored by Anthony Chang
Browse files

clang-format

parent 597155e8
...@@ -26,11 +26,11 @@ using F32 = float; ...@@ -26,11 +26,11 @@ using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F16;
using CDataType = F16; using CDataType = F16;
using C0DataType = F16; using C0DataType = F16;
using AccDataType = F32; using AccDataType = F32;
using CShuffleDataType = F16; using CShuffleDataType = F16;
using ALayout = ck::tensor_layout::gemm::RowMajor; using ALayout = ck::tensor_layout::gemm::RowMajor;
...@@ -39,7 +39,7 @@ using CLayout = ck::tensor_layout::gemm::RowMajor; ...@@ -39,7 +39,7 @@ using CLayout = ck::tensor_layout::gemm::RowMajor;
struct Relu struct Relu
{ {
template<typename OutT, typename InT> template <typename OutT, typename InT>
__host__ __device__ void operator()(OutT& y, const InT& x) const __host__ __device__ void operator()(OutT& y, const InT& x) const
{ {
y = x > 0 ? x : 0; y = x > 0 ? x : 0;
...@@ -187,10 +187,10 @@ int main(int argc, char* argv[]) ...@@ -187,10 +187,10 @@ int main(int argc, char* argv[])
c0_gamma_buf.ToDevice(c0_n_gamma.mData.data()); c0_gamma_buf.ToDevice(c0_n_gamma.mData.data());
c0_beta_buf.ToDevice(c0_n_beta.mData.data()); c0_beta_buf.ToDevice(c0_n_beta.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
auto acc_element_op = AccElementOp{}; auto acc_element_op = AccElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
// do GEMM // do GEMM
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
...@@ -262,8 +262,11 @@ int main(int argc, char* argv[]) ...@@ -262,8 +262,11 @@ int main(int argc, char* argv[])
} }
else if constexpr(std::is_same<CShuffleDataType, F16>::value) else if constexpr(std::is_same<CShuffleDataType, F16>::value)
{ {
pass &= ck::utils::check_err( pass &= ck::utils::check_err(c_m_n_device_result.mData,
c_m_n_device_result.mData, c_m_n_host_result.mData, "Error: Incorrect results c", 1e-2, 1e-2); c_m_n_host_result.mData,
"Error: Incorrect results c",
1e-2,
1e-2);
} }
} }
return pass ? 0 : 1; return pass ? 0 : 1;
......
...@@ -462,8 +462,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -462,8 +462,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
acc_element_op_{acc_element_op}, acc_element_op_{acc_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_, block_2_ctile_map_)) b_grid_desc_bk0_n_bk1_,
c_grid_desc_m_n_,
block_2_ctile_map_))
{ {
c_grid_desc_mblock_mperblock_nblock_nperblock_ = c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...@@ -519,8 +521,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator ...@@ -519,8 +521,10 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
} }
#endif #endif
if(!GridwiseGemm::CheckValidity( if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_, arg.block_2_ctile_map_)) arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_,
arg.block_2_ctile_map_))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
......
...@@ -37,7 +37,7 @@ __global__ void ...@@ -37,7 +37,7 @@ __global__ void
kernel_gemm_layernorm_xdl_cshuffle_v1( kernel_gemm_layernorm_xdl_cshuffle_v1(
const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, // MxN FloatC* __restrict__ p_c_grid, // MxN
const FloatC0* __restrict__ p_c0_bias_grid, // 1xN const FloatC0* __restrict__ p_c0_bias_grid, // 1xN
const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN
const FloatC0* __restrict__ p_c0_beta_grid, // 1xN const FloatC0* __restrict__ p_c0_beta_grid, // 1xN
...@@ -218,15 +218,20 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -218,15 +218,20 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
// Align 16 bytes (maximum LDS read/write width) // Align 16 bytes (maximum LDS read/write width)
constexpr auto c_block_size_aligned = math::integer_least_multiple( constexpr auto c_block_size_aligned =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * sizeof(FloatCShuffle), 16) / sizeof(FloatCShuffle); math::integer_least_multiple(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() *
sizeof(FloatCShuffle),
16) /
sizeof(FloatCShuffle);
// LDS allocation for reduction workspace // LDS allocation for reduction workspace
constexpr index_t c_lds_workspace_size = BlockSize; constexpr index_t c_lds_workspace_size = BlockSize;
return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
sizeof(FloatAB), sizeof(FloatAB),
c_block_size_aligned * sizeof(FloatCShuffle) + c_lds_workspace_size * sizeof(FloatReduceAcc)); c_block_size_aligned * sizeof(FloatCShuffle) +
c_lds_workspace_size * sizeof(FloatReduceAcc));
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...@@ -738,11 +743,17 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -738,11 +743,17 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());
// Align 16 bytes (maximum LDS read/write width) // Align 16 bytes (maximum LDS read/write width)
constexpr auto c_block_size_aligned = math::integer_least_multiple( constexpr auto c_block_size_aligned =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * sizeof(FloatCShuffle), 16) / sizeof(FloatCShuffle); math::integer_least_multiple(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() *
sizeof(FloatCShuffle),
16) /
sizeof(FloatCShuffle);
auto d_reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto d_reduce_work_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
reinterpret_cast<FloatReduceAcc*>(static_cast<FloatCShuffle*>(p_shared) + c_block_size_aligned), BlockSize); reinterpret_cast<FloatReduceAcc*>(static_cast<FloatCShuffle*>(p_shared) +
c_block_size_aligned),
BlockSize);
// Sum thread workspace // Sum thread workspace
auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>( auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
......
...@@ -149,7 +149,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator ...@@ -149,7 +149,7 @@ struct ReferenceGemmLayernorm : public device::BaseOperator
RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_bias_, arg.c0_n_gamma_, arg.c0_n_beta_); RunLayernorm(arg.c_m_n_, acc_m_n, arg.c0_n_bias_, arg.c0_n_gamma_, arg.c0_n_beta_);
arg.c_m_n_.ForEach([&](auto& self, auto idx){ arg.c_m_n_.ForEach([&](auto& self, auto idx) {
arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1])); arg.c_element_op_(self(idx[0], idx[1]), self(idx[0], idx[1]));
}); });
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment