Commit c0972543 authored by Jing Zhang's avatar Jing Zhang
Browse files

clean; fixed comments

parent 1c485e01
...@@ -56,9 +56,6 @@ struct SimpleDeviceMem ...@@ -56,9 +56,6 @@ struct SimpleDeviceMem
int main() int main()
{ {
std::mt19937 gen(19391);
std::uniform_int_distribution<> distrib(1, 10);
std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs; std::vector<int> Ms, Ns, Ks, StrideAs, StrideBs, StrideEs;
int sum_of_m = 0; int sum_of_m = 0;
...@@ -123,7 +120,7 @@ int main() ...@@ -123,7 +120,7 @@ int main()
e_dev_bufs.emplace_back(sizeof(EDataType) * e_dev_bufs.emplace_back(sizeof(EDataType) *
f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{})); f_matrix_space_size(Ms[i], Ns[i], StrideEs[i], ELayout{}));
gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 0, StrideBs[i], 0, {0}}); gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], 1, StrideBs[i], 1, {0}});
p_e.push_back(e_dev_bufs[i].GetDeviceBuffer()); p_e.push_back(e_dev_bufs[i].GetDeviceBuffer());
...@@ -248,7 +245,6 @@ int main() ...@@ -248,7 +245,6 @@ int main()
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get())); SimpleDeviceMem gemm_desc_workspace(op_ptr->GetWorkSpaceSize(argument_ptr.get()));
// op_ptr->SetWorkSpacePointer(argument_ptr.get(), gemm_desc_workspace.GetDeviceBuffer());
if(op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
......
...@@ -202,13 +202,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -202,13 +202,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer());
gemm_descs.push_back({sum_of_m, gemm_descs.push_back(
problem_size.Ns[i], {1, problem_size.Ns[i], problem_size.Ks[i], 1, problem_size.stride_Bs[i], 1, {0}});
problem_size.Ks[i],
0,
problem_size.stride_Bs[i],
0,
{0}});
grouped_gemm_kernel_args_.push_back( grouped_gemm_kernel_args_.push_back(
{a_tensors_device[i]->GetDeviceBuffer(), {a_tensors_device[i]->GetDeviceBuffer(),
...@@ -320,8 +315,7 @@ int main(int argc, char* argv[]) ...@@ -320,8 +315,7 @@ int main(int argc, char* argv[])
problem_size.group_count = 16; problem_size.group_count = 16;
problem_size.Ms = { problem_size.Ms = {167, 0, 177, 181, 153, 0, 156, 173, 645, 150, 204, 184, 168, 156, 168, 148};
167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
......
...@@ -188,9 +188,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -188,9 +188,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
gemm_descs.push_back({sum_of_m, gemm_descs.push_back({sum_of_m,
problem_size.Ns[i], problem_size.Ns[i],
problem_size.Ks[i], problem_size.Ks[i],
problem_size.stride_As[i], 1,
problem_size.stride_Bs[i], problem_size.stride_Bs[i],
problem_size.stride_Cs[i], 1,
{}}); {}});
grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(), grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(),
...@@ -223,8 +223,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -223,8 +223,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument)); DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
// gemm.SetWorkSpacePointer(&argument, gemm_desc_workspace.GetDeviceBuffer());
hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(), hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
grouped_gemm_kernel_args_.data(), grouped_gemm_kernel_args_.data(),
gemm.GetWorkSpaceSize(&argument), gemm.GetWorkSpaceSize(&argument),
...@@ -286,7 +284,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ...@@ -286,7 +284,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
return pass; return pass;
} }
// int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }
int main(int argc, char* argv[]) int main(int argc, char* argv[])
{ {
ProblemSize problem_size; ProblemSize problem_size;
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <vector> #include <array>
#include "device_grouped_gemm.hpp" #include "device_grouped_gemm.hpp"
......
...@@ -56,47 +56,11 @@ __global__ void ...@@ -56,47 +56,11 @@ __global__ void
const auto gemm_desc_ptr = const auto gemm_desc_ptr =
reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const)); reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
#if 0
index_t left = 0;
index_t right = group_count;
index_t group_id = index_t((left + right) / 2);
while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ &&
block_id < gemm_desc_ptr[group_id].BlockEnd_)) &&
left <= right)
{
if(block_id < gemm_desc_ptr[group_id].BlockStart_)
{
right = group_id;
}
else
{
left = group_id;
}
group_id = index_t((left + right) / 2);
}
#endif
const index_t group_id = block_id / grid_size_grp; const index_t group_id = block_id / grid_size_grp;
if(group_id >= group_count) if(group_id >= group_count)
return; return;
#if 0
GridwiseGemm::template Run<HasMainKBlockLoop>(
gemm_desc_ptr[group_id].a_ptr_,
gemm_desc_ptr[group_id].b_ptr_,
gemm_desc_ptr[group_id].ds_ptr_,
gemm_desc_ptr[group_id].e_ptr_,
p_shared,
a_element_op,
b_element_op,
c_element_op,
gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_,
gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_,
gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
gemm_desc_ptr[group_id].block_2_etile_map_);
#else
const index_t M = gemm_desc_ptr[group_id].M; const index_t M = gemm_desc_ptr[group_id].M;
const index_t N = gemm_desc_ptr[group_id].N; const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K; const index_t K = gemm_desc_ptr[group_id].K;
...@@ -158,9 +122,6 @@ __global__ void ...@@ -158,9 +122,6 @@ __global__ void
m_id += 1; m_id += 1;
} while(m_id < m_loops); } while(m_id < m_loops);
#endif
#else #else
ignore = gemm_descs_const; ignore = gemm_descs_const;
ignore = group_count; ignore = group_count;
...@@ -644,9 +605,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -644,9 +605,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
// std::cout << "grp id: " << group_id << " grid_size: " << grid_size_grp <<
// std::endl;
const index_t BlockStart = grid_size_; const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp; const index_t BlockEnd = grid_size_ + grid_size_grp;
...@@ -731,11 +689,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -731,11 +689,9 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
{ {
bool has_main_k_block_loop = true; bool has_main_k_block_loop = true;
#if 1
std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args; std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args;
grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size()); grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size());
#endif
for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
{ {
...@@ -788,7 +744,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -788,7 +744,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
} }
#if 1
grouped_gemm_kernel_args.push_back( grouped_gemm_kernel_args.push_back(
GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_, GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_,
arg.gemm_desc_kernel_arg_[i].b_ptr_, arg.gemm_desc_kernel_arg_[i].b_ptr_,
...@@ -801,7 +756,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -801,7 +756,6 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
arg.gemm_desc_kernel_arg_[i].StrideB_, arg.gemm_desc_kernel_arg_[i].StrideB_,
arg.gemm_desc_kernel_arg_[i].StrideDs_, arg.gemm_desc_kernel_arg_[i].StrideDs_,
arg.gemm_desc_kernel_arg_[i].StrideE_}); arg.gemm_desc_kernel_arg_[i].StrideE_});
#endif
} }
float ave_time = 0; float ave_time = 0;
......
...@@ -75,7 +75,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan ...@@ -75,7 +75,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_normalization_instan
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_softmax_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batchnorm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_bias_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_gemm_fastgelu_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance) target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment