Commit 6f774178 authored by root

pass device arrays as separate args
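
Instead of embedding the per-group A/B/Ds/E tensor pointers in each GemmBiasTransKernelArg, the host now collects them in separate vectors, copies each array into its own aligned slice of the workspace, and passes the resulting device arrays to the kernel as dedicated arguments; the kernel indexes them with the group id found by the existing binary search.

A minimal standalone sketch of the aligned packing scheme used below (round_up and pack_section are illustrative names, not from the source; std::memcpy stands in for the hipMemcpy host-to-device copies):

#include <cstddef>
#include <cstring>

// Round x up to the next multiple of align, as math::integer_least_multiple does.
std::size_t round_up(std::size_t x, std::size_t align)
{
    return ((x + align - 1) / align) * align;
}

// Copy one host array into the workspace at offset off and return the aligned
// offset at which the next section starts.
std::size_t pack_section(char* workspace, std::size_t off,
                         const void* src, std::size_t bytes, std::size_t align)
{
    std::memcpy(workspace + off, src, bytes);
    return round_up(off + bytes, align);
}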

parent 715e8dd2
@@ -23,6 +23,9 @@ namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
+ typename ADataType,
+ typename BDataType,
+ typename EDataType,
typename GemmDesc,
typename AElementwiseOperation,
typename BElementwiseOperation,
@@ -32,23 +35,46 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
- kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
- const index_t group_count,
- const AElementwiseOperation a_element_op,
- const BElementwiseOperation b_element_op,
- const CDEElementwiseOperation c_element_op)
+ kernel_grouped_gemm_xdl(
+ #if 0
+ const void CK_CONSTANT_ADDRESS_SPACE* a_ptr,
+ const void CK_CONSTANT_ADDRESS_SPACE* b_ptr,
+ const void CK_CONSTANT_ADDRESS_SPACE* ds_ptr,
+ const void CK_CONSTANT_ADDRESS_SPACE* e_ptr,
+ #endif
+ const ADataType** a_ptr_,
+ const BDataType** b_ptr_,
+ const typename GridwiseGemm::DsGridPointer* ds_ptr_,
+ EDataType** e_ptr_,
+ const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_ptr_,
+ const index_t group_count,
+ const AElementwiseOperation a_element_op,
+ const BElementwiseOperation b_element_op,
+ const CDEElementwiseOperation c_element_op)
{
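// NOTE: the per-group A/B/Ds/E pointer arrays now arrive as separate kernel arguments
// (a_ptr_, b_ptr_, ds_ptr_, e_ptr_) indexed by group id, instead of being stored in
// each GemmDesc entry.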
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t block_id = get_block_1d_id();
+ #if 0
+ const auto a_ptr_ =
+ static_cast<ADataType* const*>(cast_pointer_to_generic_address_space(a_ptr));
+ const auto b_ptr_ =
+ static_cast<BDataType* const*>(cast_pointer_to_generic_address_space(b_ptr));
+ const auto ds_ptr_ = static_cast<typename GridwiseGemm::DsGridPointer const*>(
+ cast_pointer_to_generic_address_space(ds_ptr));
+ const auto e_ptr_ =
+ static_cast<EDataType* const*>(cast_pointer_to_generic_address_space(e_ptr));
+ #endif
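// The #if 0 blocks above retain the earlier constant-address-space form of the
// pointer arguments; only the plain device-array parameters are live.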
const auto gemm_desc_ptr =
- reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
+ static_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_ptr_));
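// Binary search over the contiguous [BlockStart_, BlockEnd_) ranges recorded in each
// GemmDesc to find the group this workgroup belongs to.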
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
left <= right)
@@ -64,11 +90,12 @@ __global__ void
group_id = index_t((left + right) / 2);
}
+ #if 1
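// The tensor pointers are now read from the separate device arrays by group id;
// the grid descriptors still come from the GemmDesc entry below.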
GridwiseGemm::template Run<HasMainKBlockLoop>(
- gemm_desc_ptr[group_id].a_ptr_,
- gemm_desc_ptr[group_id].b_ptr_,
- gemm_desc_ptr[group_id].ds_ptr_,
- gemm_desc_ptr[group_id].e_ptr_,
+ a_ptr_[group_id],
+ b_ptr_[group_id],
+ ds_ptr_[group_id],
+ e_ptr_[group_id],
p_shared,
a_element_op,
b_element_op,
@@ -78,6 +105,8 @@ __global__ void
gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_ptr[group_id].block_2_etile_map_);
+ #endif
#else
ignore = gemm_descs_const;
ignore = group_count;
@@ -323,12 +352,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
struct GemmBiasTransKernelArg
{
- // pointers
- const ADataType* a_ptr_;
- const BDataType* b_ptr_;
- typename GridwiseGemm::DsGridPointer ds_ptr_;
- EDataType* e_ptr_;
// tensor descriptors for problem definition
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
@@ -456,12 +479,13 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n);
+ a_ptr_.push_back(static_cast<const ADataType*>(p_As[i]));
+ b_ptr_.push_back(static_cast<const BDataType*>(p_Bs[i]));
+ ds_ptr_.push_back(p_ds_grid);
+ e_ptr_.push_back(static_cast<EDataType*>(p_Es[i]));
gemm_desc_kernel_arg_.push_back(
- GemmBiasTransKernelArg{static_cast<const ADataType*>(p_As[i]),
- static_cast<const BDataType*>(p_Bs[i]),
- p_ds_grid,
- static_cast<EDataType*>(p_Es[i]),
- a_grid_desc_m_k,
+ GemmBiasTransKernelArg{a_grid_desc_m_k,
b_grid_desc_n_k,
ds_grid_desc_m_n,
e_grid_desc_m_n,
@@ -484,6 +508,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
BElementwiseOperation b_element_op_;
CDEElementwiseOperation c_element_op_;
+ std::vector<const ADataType*> a_ptr_;
+ std::vector<const BDataType*> b_ptr_;
+ std::vector<typename GridwiseGemm::DsGridPointer> ds_ptr_;
+ std::vector<const EDataType*> e_ptr_;
std::vector<GemmBiasTransKernelArg> gemm_desc_kernel_arg_;
index_t grid_size_;
@@ -542,16 +571,70 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
}
}
+ int wg_off = 0;
+ const int align = 4;
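// Workspace layout: [a_ptr array | b_ptr array | ds_ptr array | e_ptr array | GemmDesc array],
// with wg_off bumped to the next multiple of align after each section. Since every pointer
// section is a whole number of 8-byte entries, the 4-byte rounding never actually
// misaligns the arrays.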
+ void* a_ptr_dev = static_cast<char*>(arg.p_workspace_) + wg_off;
+ hipGetErrorString(hipMemcpy(a_ptr_dev,
+ arg.a_ptr_.data(),
+ arg.a_ptr_.size() * sizeof(ADataType*),
+ hipMemcpyHostToDevice));
+ auto a_ptr = static_cast<const ADataType* const*>(arg.a_ptr_.data());
+ wg_off += arg.a_ptr_.size() * sizeof(ADataType*);
+ wg_off = math::integer_least_multiple(wg_off, align);
+ void* b_ptr_dev = static_cast<char*>(arg.p_workspace_) + wg_off;
+ hipGetErrorString(hipMemcpy(b_ptr_dev,
+ arg.b_ptr_.data(),
+ arg.b_ptr_.size() * sizeof(BDataType*),
+ hipMemcpyHostToDevice));
+ wg_off += arg.b_ptr_.size() * sizeof(BDataType*);
+ wg_off = math::integer_least_multiple(wg_off, align);
+ void* ds_ptr_dev = static_cast<char*>(arg.p_workspace_) + wg_off;
hipGetErrorString(
- hipMemcpy(arg.p_workspace_,
+ hipMemcpy(ds_ptr_dev,
+ arg.ds_ptr_.data(),
+ arg.ds_ptr_.size() * sizeof(typename GridwiseGemm::DsGridPointer),
+ hipMemcpyHostToDevice));
+ wg_off += arg.ds_ptr_.size() * sizeof(typename GridwiseGemm::DsGridPointer);
+ wg_off = math::integer_least_multiple(wg_off, align);
+ void* e_ptr_dev = static_cast<char*>(arg.p_workspace_) + wg_off;
+ hipGetErrorString(hipMemcpy(e_ptr_dev,
+ arg.e_ptr_.data(),
+ arg.e_ptr_.size() * sizeof(EDataType*),
+ hipMemcpyHostToDevice));
+ wg_off += arg.e_ptr_.size() * sizeof(EDataType*);
+ wg_off = math::integer_least_multiple(wg_off, align);
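// Finally, copy the descriptor array itself; unlike the pointer arrays, the kernel
// still receives it through the constant address space.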
+ void* gemm_desc_dev = static_cast<char*>(arg.p_workspace_) + wg_off;
+ hipGetErrorString(
+ hipMemcpy(gemm_desc_dev,
arg.gemm_desc_kernel_arg_.data(),
arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
hipMemcpyHostToDevice));
+ wg_off += arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg);
+ wg_off = math::integer_least_multiple(wg_off, align);
float ave_time = 0;
auto launch_kernel = [&](auto has_main_k_block_loop_) {
const auto kernel = kernel_grouped_gemm_xdl<GridwiseGemm,
+ ADataType,
+ BDataType,
+ EDataType,
GemmBiasTransKernelArg,
AElementwiseOperation,
BElementwiseOperation,
@@ -564,7 +647,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
dim3(arg.grid_size_),
dim3(BlockSize),
0,
- cast_pointer_to_constant_address_space(arg.p_workspace_),
+ #if 0
+ cast_pointer_to_constant_address_space(a_ptr_dev),
+ cast_pointer_to_constant_address_space(b_ptr_dev),
+ cast_pointer_to_constant_address_space(ds_ptr_dev),
+ cast_pointer_to_constant_address_space(e_ptr_dev),
+ #endif
+ static_cast<const ADataType**>(a_ptr_dev),
+ static_cast<const BDataType**>(b_ptr_dev),
+ static_cast<const typename GridwiseGemm::DsGridPointer*>(ds_ptr_dev),
+ static_cast<EDataType**>(e_ptr_dev),
+ cast_pointer_to_constant_address_space(gemm_desc_dev),
arg.gemm_desc_kernel_arg_.size(),
arg.a_element_op_,
arg.b_element_op_,
@@ -670,7 +763,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
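// The workspace now also holds the four per-group pointer arrays alongside the
// descriptor array; since each section is pointer-aligned in practice, the single
// round-up at the end suffices.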
- return dynamic_cast<const Argument*>(p_arg)->group_count_ * sizeof(GemmBiasTransKernelArg);
+ size_t wg_size = dynamic_cast<const Argument*>(p_arg)->group_count_ *
+ (sizeof(GemmBiasTransKernelArg) + sizeof(ADataType*) + sizeof(BDataType*) +
+ sizeof(typename GridwiseGemm::DsGridPointer) + sizeof(EDataType*));
+ const int align = 4;
+ wg_size = math::integer_least_multiple(wg_size, align);
+ return wg_size;
}
};
@@ -10,8 +10,8 @@ cmake
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \
- -D BUILD_DEV=ON \
- -D GPU_TARGETS="gfx908;gfx90a" \
+ -D BUILD_DEV=OFF \
+ -D GPU_TARGETS="gfx90a" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE}