Commit 6f774178 authored by root's avatar root
Browse files

pass device arrays as seperate args

parent 715e8dd2
...@@ -23,6 +23,9 @@ namespace tensor_operation { ...@@ -23,6 +23,9 @@ namespace tensor_operation {
namespace device { namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename ADataType,
typename BDataType,
typename EDataType,
typename GemmDesc, typename GemmDesc,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
...@@ -32,7 +35,18 @@ __global__ void ...@@ -32,7 +35,18 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif #endif
kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, kernel_grouped_gemm_xdl(
#if 0
const void CK_CONSTANT_ADDRESS_SPACE* a_ptr,
const void CK_CONSTANT_ADDRESS_SPACE* b_ptr,
const void CK_CONSTANT_ADDRESS_SPACE* ds_ptr,
const void CK_CONSTANT_ADDRESS_SPACE* e_ptr,
#endif
const ADataType** a_ptr_,
const BDataType** b_ptr_,
const typename GridwiseGemm::DsGridPointer* ds_ptr_,
EDataType** e_ptr_,
const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_ptr_,
const index_t group_count, const index_t group_count,
const AElementwiseOperation a_element_op, const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op, const BElementwiseOperation b_element_op,
...@@ -43,12 +57,24 @@ __global__ void ...@@ -43,12 +57,24 @@ __global__ void
const index_t block_id = get_block_1d_id(); const index_t block_id = get_block_1d_id();
#if 0
const auto a_ptr_ =
static_cast<ADataType* const*>(cast_pointer_to_generic_address_space(a_ptr));
const auto b_ptr_ =
static_cast<BDataType* const*>(cast_pointer_to_generic_address_space(b_ptr));
const auto ds_ptr_ = static_cast<typename GridwiseGemm::DsGridPointer const*>(
cast_pointer_to_generic_address_space(ds_ptr));
const auto e_ptr_ =
static_cast<EDataType* const*>(cast_pointer_to_generic_address_space(e_ptr));
#endif
const auto gemm_desc_ptr = const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const)); static_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_ptr_));
index_t left = 0; index_t left = 0;
index_t right = group_count; index_t right = group_count;
index_t group_id = index_t((left + right) / 2); index_t group_id = index_t((left + right) / 2);
while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ && while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
block_id < gemm_desc_ptr[group_id].BlockEnd_)) && block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
left <= right) left <= right)
...@@ -64,11 +90,12 @@ __global__ void ...@@ -64,11 +90,12 @@ __global__ void
group_id = index_t((left + right) / 2); group_id = index_t((left + right) / 2);
} }
#if 1
GridwiseGemm::template Run<HasMainKBlockLoop>( GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_ptr[group_id].a_ptr_, a_ptr_[group_id],
gemm_desc_ptr[group_id].b_ptr_, b_ptr_[group_id],
gemm_desc_ptr[group_id].ds_ptr_, ds_ptr_[group_id],
gemm_desc_ptr[group_id].e_ptr_, e_ptr_[group_id],
p_shared, p_shared,
a_element_op, a_element_op,
b_element_op, b_element_op,
...@@ -78,6 +105,8 @@ __global__ void ...@@ -78,6 +105,8 @@ __global__ void
gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_ptr[group_id].block_2_etile_map_); gemm_desc_ptr[group_id].block_2_etile_map_);
#endif
#else #else
ignore = gemm_descs_const; ignore = gemm_descs_const;
ignore = group_count; ignore = group_count;
...@@ -323,12 +352,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -323,12 +352,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
struct GemmBiasTransKernelArg struct GemmBiasTransKernelArg
{ {
// pointers
const ADataType* a_ptr_;
const BDataType* b_ptr_;
typename GridwiseGemm::DsGridPointer ds_ptr_;
EDataType* e_ptr_;
// tensor descriptors for problem definiton // tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_; AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_; BGridDesc_N_K b_grid_desc_n_k_;
...@@ -456,12 +479,13 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -456,12 +479,13 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n); e_grid_desc_m_n);
a_ptr_.push_back(static_cast<const ADataType*>(p_As[i]));
b_ptr_.push_back(static_cast<const BDataType*>(p_Bs[i]));
ds_ptr_.push_back(p_ds_grid);
e_ptr_.push_back(static_cast<EDataType*>(p_Es[i]));
gemm_desc_kernel_arg_.push_back( gemm_desc_kernel_arg_.push_back(
GemmBiasTransKernelArg{static_cast<const ADataType*>(p_As[i]), GemmBiasTransKernelArg{a_grid_desc_m_k,
static_cast<const BDataType*>(p_Bs[i]),
p_ds_grid,
static_cast<EDataType*>(p_Es[i]),
a_grid_desc_m_k,
b_grid_desc_n_k, b_grid_desc_n_k,
ds_grid_desc_m_n, ds_grid_desc_m_n,
e_grid_desc_m_n, e_grid_desc_m_n,
...@@ -484,6 +508,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -484,6 +508,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation c_element_op_; CDEElementwiseOperation c_element_op_;
std::vector<const ADataType*> a_ptr_;
std::vector<const BDataType*> b_ptr_;
std::vector<typename GridwiseGemm::DsGridPointer> ds_ptr_;
std::vector<const EDataType*> e_ptr_;
std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_; std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
index_t grid_size_; index_t grid_size_;
...@@ -542,16 +571,70 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -542,16 +571,70 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
} }
} }
int wg_off = 0;
const int align = 4;
void* a_ptr_dev = static_cast<char*>(arg.p_workspace_) + wg_off;
hipGetErrorString(hipMemcpy(a_ptr_dev,
arg.a_ptr_.data(),
arg.a_ptr_.size() * sizeof(ADataType*),
hipMemcpyHostToDevice));
auto a_ptr = static_cast<const ADataType* const*>(arg.a_ptr_.data());
wg_off += arg.a_ptr_.size() * sizeof(ADataType*);
wg_off = math::integer_least_multiple(wg_off, align);
void* b_ptr_dev = static_cast<char*>(arg.p_workspace_) + wg_off;
hipGetErrorString(hipMemcpy(b_ptr_dev,
arg.b_ptr_.data(),
arg.b_ptr_.size() * sizeof(BDataType*),
hipMemcpyHostToDevice));
wg_off += arg.b_ptr_.size() * sizeof(BDataType*);
wg_off = math::integer_least_multiple(wg_off, align);
void* ds_ptr_dev = static_cast<char*>(arg.p_workspace_) + wg_off;
hipGetErrorString(
hipMemcpy(ds_ptr_dev,
arg.ds_ptr_.data(),
arg.ds_ptr_.size() * sizeof(typename GridwiseGemm::DsGridPointer),
hipMemcpyHostToDevice));
wg_off += arg.ds_ptr_.size() * sizeof(typename GridwiseGemm::DsGridPointer);
wg_off = math::integer_least_multiple(wg_off, align);
void* e_ptr_dev = static_cast<char*>(arg.p_workspace_) + wg_off;
hipGetErrorString(hipMemcpy(e_ptr_dev,
arg.e_ptr_.data(),
arg.e_ptr_.size() * sizeof(EDataType*),
hipMemcpyHostToDevice));
wg_off += arg.e_ptr_.size() * sizeof(EDataType*);
wg_off = math::integer_least_multiple(wg_off, align);
void* gemm_desc_dev = static_cast<char*>(arg.p_workspace_) + wg_off;
hipGetErrorString( hipGetErrorString(
hipMemcpy(arg.p_workspace_, hipMemcpy(gemm_desc_dev,
arg.gemm_desc_kernel_arg_.data(), arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg), arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice)); hipMemcpyHostToDevice));
wg_off += arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg);
wg_off = math::integer_least_multiple(wg_off, align);
float ave_time = 0; float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) { auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm, const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
ADataType,
BDataType,
EDataType,
GemmBiasTransKernelArg, GemmBiasTransKernelArg,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
...@@ -564,7 +647,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -564,7 +647,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
dim3(arg.grid_size_), dim3(arg.grid_size_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
cast_pointer_to_constant_address_space(arg.p_workspace_), #if 0
cast_pointer_to_constant_address_space(a_ptr_dev),
cast_pointer_to_constant_address_space(b_ptr_dev),
cast_pointer_to_constant_address_space(ds_ptr_dev),
cast_pointer_to_constant_address_space(e_ptr_dev),
#endif
static_cast<const ADataType**>(a_ptr_dev),
static_cast<const BDataType**>(b_ptr_dev),
static_cast<const typename GridwiseGemm::DsGridPointer*>(ds_ptr_dev),
static_cast<EDataType**>(e_ptr_dev),
cast_pointer_to_constant_address_space(gemm_desc_dev),
arg.gemm_desc_kernel_arg_.size(), arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_, arg.a_element_op_,
arg.b_element_op_, arg.b_element_op_,
...@@ -670,7 +763,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout, ...@@ -670,7 +763,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{ {
return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg); size_t wg_size = dynamic_cast<const Argument*>(p_arg)->group_count_ *
(sizeof(GemmBiasTransKernelArg) + sizeof(ADataType*) + sizeof(BDataType*) +
sizeof(typename GridwiseGemm::DsGridPointer) + sizeof(EDataType*));
const int align = 4;
wg_size = math::integer_least_multiple(wg_size, align);
return wg_size;
} }
}; };
......
...@@ -10,8 +10,8 @@ cmake ...@@ -10,8 +10,8 @@ cmake
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=OFF \
-D GPU_TARGETS="gfx908;gfx90a" \ -D GPU_TARGETS="gfx90a" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \ -D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE} ${MY_PROJECT_SOURCE}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment