Commit 0f49ce23 authored by Jing Zhang's avatar Jing Zhang Committed by root
Browse files

void data pointers

parent a74cbab8
...@@ -25,6 +25,9 @@ namespace device { ...@@ -25,6 +25,9 @@ namespace device {
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename GemmDesc, typename GemmDesc,
typename FloatA,
typename FloatB,
typename FloatC,
bool HasMainKBlockLoop, bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation> InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void __global__ void
...@@ -77,9 +80,10 @@ __global__ void ...@@ -77,9 +80,10 @@ __global__ void
} }
#endif #endif
const auto p_a_grid = gemm_desc_ptr[group_id].p_a_grid; const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
const auto p_b_grid = gemm_desc_ptr[group_id].p_b_grid; const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
const auto p_c_grid = gemm_desc_ptr[group_id].p_c_grid; const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);
const auto M = gemm_desc_ptr[group_id].M; const auto M = gemm_desc_ptr[group_id].M;
const auto N = gemm_desc_ptr[group_id].N; const auto N = gemm_desc_ptr[group_id].N;
const auto K = gemm_desc_ptr[group_id].K; const auto K = gemm_desc_ptr[group_id].K;
...@@ -400,9 +404,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -400,9 +404,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{ {
struct SimpleGemmArgument struct SimpleGemmArgument
{ {
const ADataType* p_a_grid; const void* p_a_grid;
const BDataType* p_b_grid; const void* p_b_grid;
EDataType* p_c_grid; void* p_c_grid;
index_t M; index_t M;
index_t N; index_t N;
...@@ -517,6 +521,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -517,6 +521,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const auto kernel = const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm, kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmArgumentType, GemmArgumentType,
ADataType,
BDataType,
EDataType,
true, true,
InMemoryDataOperationEnum::AtomicAdd>; InMemoryDataOperationEnum::AtomicAdd>;
...@@ -527,6 +534,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -527,6 +534,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const auto kernel = const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm, kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmArgumentType, GemmArgumentType,
ADataType,
BDataType,
EDataType,
true, true,
InMemoryDataOperationEnum::Set>; InMemoryDataOperationEnum::Set>;
...@@ -540,6 +550,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -540,6 +550,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const auto kernel = const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm, kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmArgumentType, GemmArgumentType,
ADataType,
BDataType,
EDataType,
false, false,
InMemoryDataOperationEnum::AtomicAdd>; InMemoryDataOperationEnum::AtomicAdd>;
...@@ -550,6 +563,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo ...@@ -550,6 +563,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const auto kernel = const auto kernel =
kernel_grouped_gemm_xdl_splitk<GridwiseGemm, kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
GemmArgumentType, GemmArgumentType,
ADataType,
BDataType,
EDataType,
false, false,
InMemoryDataOperationEnum::Set>; InMemoryDataOperationEnum::Set>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment