Commit 30109f67 authored by Chao Liu's avatar Chao Liu
Browse files

clean

parent 52ce27bc
...@@ -504,23 +504,6 @@ struct DeviceGemmBiasActivation_Xdl_CShuffle ...@@ -504,23 +504,6 @@ struct DeviceGemmBiasActivation_Xdl_CShuffle
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
#if 0
{
std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
<< arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
<< arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_, arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_m_n_, arg.c_grid_desc_m_n_,
...@@ -535,10 +518,9 @@ struct DeviceGemmBiasActivation_Xdl_CShuffle ...@@ -535,10 +518,9 @@ struct DeviceGemmBiasActivation_Xdl_CShuffle
const auto K = const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
float ave_time = 0; auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
const auto kernel = kernel_gemm_bias_activation_xdl_cshuffle< const auto kernel = kernel_gemm_bias_activation_xdl_cshuffle<
GridwiseGemm, GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
...@@ -550,55 +532,34 @@ struct DeviceGemmBiasActivation_Xdl_CShuffle ...@@ -550,55 +532,34 @@ struct DeviceGemmBiasActivation_Xdl_CShuffle
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
true>; has_main_loop>;
ave_time = return launch_and_time_kernel(stream_config,
launch_and_time_kernel(stream_config, kernel,
kernel, dim3(grid_size),
dim3(grid_size), dim3(BlockSize),
dim3(BlockSize), 0,
0, arg.p_a_grid_,
arg.p_a_grid_, arg.p_b_grid_,
arg.p_b_grid_, arg.p_c_grid_,
arg.p_c_grid_, arg.a_element_op_,
arg.a_element_op_, arg.b_element_op_,
arg.b_element_op_, arg.c_element_op_,
arg.c_element_op_, arg.a_grid_desc_ak0_m_ak1_,
arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_,
arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, arg.block_2_ctile_map_);
arg.block_2_ctile_map_); };
float ave_time = 0;
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
ave_time = launch_kernel(integral_constant<bool, true>{});
} }
else else
{ {
const auto kernel = kernel_gemm_bias_activation_xdl_cshuffle< ave_time = launch_kernel(integral_constant<bool, false>{});
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
ave_time =
launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_ctile_map_);
} }
return ave_time; return ave_time;
......
#pragma once #pragma once
#include "common_header.hpp" #include "common_header.hpp"
#include "multi_index_transform_helper.hpp" #include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -235,6 +236,7 @@ struct GridwiseGemmBiasActivation_k0mk1_k0nk1_mn_xdl_cshuffle ...@@ -235,6 +236,7 @@ struct GridwiseGemmBiasActivation_k0mk1_k0nk1_mn_xdl_cshuffle
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// divide block work by [M, N] // divide block work by [M, N]
const auto block_work_idx = const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment