Commit f58137f2 authored by carlushuang

fix several bugs

parent 8eaed8b3
@@ -35,14 +35,15 @@ using AccDataType = F32;
 using CDataType = F16;
 using ALayout = Row;
-using BLayout = Col;
+using BLayout = Row;
 using CLayout = Row;
 using AElementOp = PassThrough;
 using BElementOp = PassThrough;
 using CElementOp = PassThrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding;
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
 // clang-format off
@@ -50,7 +51,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
 //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+        // < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
+        < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 16, 1, 8>, 8>;
 // clang-format on
 #include "run_splitK_gemm_example.inc"
...
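Note on the hunk above: the example's B matrix moves from column-major to row-major, and the device instance swaps the `Default` specialization for `MNPadding` with a smaller 128x32x64 tile, so M and N no longer have to divide the tile sizes. A minimal sketch of what the padding specialization implies (plain C++ for illustration only; the tile numbers come from the new instance, the problem sizes are made up):

```cpp
// Illustration only: with MN padding the block-to-tile map rounds the problem
// up to whole tiles instead of requiring exact divisibility.
#include <cstdio>

int main()
{
    const int M = 1000, N = 500;              // sizes not divisible by the tiles
    const int MPerBlock = 32, NPerBlock = 64; // tile shape from the new instance
    const int MPad = (M + MPerBlock - 1) / MPerBlock * MPerBlock;
    const int NPad = (N + NPerBlock - 1) / NPerBlock * NPerBlock;
    std::printf("grid covers %d x %d (padded from %d x %d)\n", MPad, NPad, M, N);
    return 0;
}
```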
@@ -111,14 +111,34 @@ struct ThreadGroupTensorSliceTransfer_v6r1r2
         }
     }
-    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
+    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
     {
-        threadwise_transfer_.SetSrcSliceOrigin(src_desc, src_slice_origin_idx);
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
+        {
+            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+                make_multi_index(ThreadGroup::GetThreadId()));
+            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
+            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
+                                                   src_block_slice_origin + thread_data_idx_begin);
+        }
     }
-    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
+    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_block_slice_origin)
     {
-        threadwise_transfer_.SetDstSliceOrigin(dst_desc, dst_slice_origin_idx);
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
+        {
+            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+                make_multi_index(ThreadGroup::GetThreadId()));
+            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
+            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
+                                                   dst_block_slice_origin + thread_data_idx_begin);
+        }
     }
     private:
...
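Note on the hunk above: the old `SetSrcSliceOrigin`/`SetDstSliceOrigin` forwarded the block-level origin straight to `threadwise_transfer_`, so every thread in the group ended up pointing at the same slice. The fix re-derives each thread's own starting coordinate from its position in the thread cluster, matching what the constructor already does. A standalone sketch of that index arithmetic (plain C++ with assumed 2-D shapes standing in for CK's descriptor types):

```cpp
// Decompose the linear thread id into a cluster coordinate, scale by the
// per-thread slice lengths, then add the block-level origin.
#include <array>
#include <cstdio>

int main()
{
    const std::array<int, 2> cluster_lengths{4, 64}; // threads per dimension
    const std::array<int, 2> slice_lengths{8, 1};    // elements each thread owns
    const std::array<int, 2> block_origin{128, 0};   // block-level slice origin

    const int tid = 70; // example thread id
    const std::array<int, 2> cluster_idx{tid / cluster_lengths[1],
                                         tid % cluster_lengths[1]};
    std::array<int, 2> origin{};
    for(int d = 0; d < 2; ++d)
        origin[d] = block_origin[d] + cluster_idx[d] * slice_lengths[d];
    std::printf("thread %d origin: (%d, %d)\n", tid, origin[0], origin[1]);
    return 0;
}
```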
@@ -833,7 +833,8 @@ struct BlockToCTileMap_GemmStreamK
         printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
                "sk_num_blocks:%d, "
                "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
-               "k_iters_per_tile:%d, k_iters_per_big_block:%d\n",
+               "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
+               "sk_tiles:%u, workspace(acc float):%u\n",
               num_cu,
               occupancy,
               get_grid_dims().x,
@@ -846,7 +847,10 @@ struct BlockToCTileMap_GemmStreamK
               dp_iters_per_block,
               dp_num_blocks,
               k_iters_per_tile.get(),
-              k_iters_per_big_block);
+              k_iters_per_big_block,
+              reduction_start_block_idx,
+              get_sk_tiles(),
+              get_workspace_size(sizeof(float)));
     }
     __host__ __device__ uint32_t get_sk_total_iters() const
...
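Note on the hunk above: the extended `printf` exposes the stream-K schedule, including where the reduction blocks start and how much float workspace the partial tiles need. For orientation, a toy version of the usual stream-K split (illustrative arithmetic only; CK's actual partitioning lives in `BlockToCTileMap_GemmStreamK` and may differ in detail):

```cpp
// Toy stream-K bookkeeping: spread the K-iterations of the stream-K tiles
// across sk_num_blocks, giving the first sk_num_big_blocks one extra
// iteration each so the total divides evenly.
#include <cstdio>

int main()
{
    const int sk_tiles = 5, k_iters_per_tile = 16;
    const int sk_num_blocks    = 12;
    const int sk_total_iters   = sk_tiles * k_iters_per_tile;       // 80
    const int iters_per_block  = sk_total_iters / sk_num_blocks;    // 6
    const int sk_num_big_blocks = sk_total_iters % sk_num_blocks;   // 8 blocks do 7
    std::printf("big: %d blocks x %d iters, normal: %d blocks x %d iters\n",
                sk_num_big_blocks, iters_per_block + 1,
                sk_num_blocks - sk_num_big_blocks, iters_per_block);
    return 0;
}
```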
@@ -23,9 +23,8 @@ namespace ck {
 template <typename GridwiseGemm>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
 #endif
-    // kernel_gemm_xdlops_streamk(typename GridwiseGemm::Argument karg)
     kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
                                const typename GridwiseGemm::FloatAB* p_b_grid,
                                typename GridwiseGemm::FloatC* p_c_grid,
@@ -43,7 +42,6 @@ __global__ void
     __shared__ uint8_t p_shared[shared_size];
-    // GridwiseGemm::Run(karg, static_cast<void*>(p_shared));
     GridwiseGemm::Run(p_a_grid,
                       p_b_grid,
                       p_c_grid,
@@ -549,6 +547,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
     {
         // descriptors
        constexpr auto cluster_length_reduce = GetClusterLengthReduction();
+       constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce);
+       const auto reduce_thread_cluster_idx =
+           reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
+       const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
+       const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];
        constexpr auto MReduceIters =
            math::integer_divide_ceil(Number<MPerBlock>{}, cluster_length_reduce.At(I0));
@@ -560,13 +563,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
        //     make_tuple(Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
        constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed(
-           make_tuple(Number<1>{}, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
+           make_tuple(I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
        constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed(
-           make_tuple(Number<1>{},
-                      Number<1>{},
-                      Number<1>{},
-                      Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
+           make_tuple(I1, I1, I1, Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
        constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor();
@@ -627,7 +627,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
            CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector,
            1,                                          // SrcScalarStrideInVector,
            false                                       // SrcResetCoordinateAfterRun,
-           >{c_partial_acc_block_m_n, make_multi_index(0, 0)};
+           >{c_partial_acc_block_m_n,
+             make_multi_index(thread_m_cluster_id,
+                              thread_n_cluster_id *
+                                  CBlockTransferScalarPerVector_NWaveNPerXDL)};
        auto acc_store = ThreadwiseTensorSliceTransfer_v1r3<
            FloatAcc,                                                // SrcData,
@@ -635,18 +638,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
            decltype(acc_thread_buf_store_desc),                     // SrcDesc,
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc,
            CElementwiseOperation,                                   // ElementwiseOperation,
-           Sequence<0, 0, 0, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
+           Sequence<1, 1, 1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths,
            Sequence<0, 1, 2, 3>,                       // DimAccessOrder,
-           2,                                          // DstVectorDim,
+           3,                                          // DstVectorDim,
            CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector,
            InMemoryDataOperationEnum::Set,             // InMemoryDataOperationEnum DstInMemOp,
-           3,                                          // DstScalarStrideInVector,
+           1,                                          // DstScalarStrideInVector,
            false                                       // DstResetCoordinateAfterRun,
            >{c_grid_desc_mblock_mperblock_nblock_nperblock,
             make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
-                             0,
+                             thread_m_cluster_id,
                              __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
-                             0),
+                             thread_n_cluster_id *
+                                 CBlockTransferScalarPerVector_NWaveNPerXDL),
             CElementwiseOperation{}};
        // block synchronization
@@ -659,8 +663,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
        static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) {
            for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
            {
-               auto c_partial_acc_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-                   static_cast<FloatAcc*>(p_workspace) + i,
+               auto c_partial_acc_buf =
+                   make_dynamic_buffer<AddressSpaceEnum::Global,
+                                       amd_buffer_coherence_bits::glc>(
+                       reinterpret_cast<FloatAcc*>(p_workspace) +
+                           i * c_partial_acc_block_m_n.GetElementSpaceSize(),
                        c_partial_acc_block_m_n.GetElementSpaceSize());
                acc_load.Run(c_partial_acc_block_m_n,
@@ -850,12 +857,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
            GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle();
        auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatCShuffle*>(p_shared_block),
+           reinterpret_cast<FloatCShuffle*>(p_shared_block),
            c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());
-       auto c_partial_acc_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           static_cast<FloatAcc*>(p_workspace) + block_acc_offset,
-           c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle.GetElementSpaceSize());
+       auto c_partial_acc_buf =
+           make_dynamic_buffer<AddressSpaceEnum::Global, amd_buffer_coherence_bits::glc>(
+               reinterpret_cast<FloatAcc*>(p_workspace) + block_acc_offset,
+               c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
+                   .GetElementSpaceSize());
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
@@ -984,7 +993,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
            3,                                          // index_t VectorDim,
            CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector,
            true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
-           false> // bool ThreadTransferDstResetCoordinateAfterRun
+           true>  // bool ThreadTransferDstResetCoordinateAfterRun
            {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
             make_multi_index(0, 0, 0, 0),
             c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
...
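Note on the hunks above: besides wiring per-thread cluster offsets into `acc_load`/`acc_store` (with the matching corrections: for `Sequence<1, 1, 1, N>` slice lengths the vector dimension is index 3 and the in-vector stride is 1), two fixes are easy to miss. The partial-accumulation workspace pointer now advances by whole tiles rather than single elements, and the global buffers are created with GLC coherence so partials written by one workgroup are visible to the workgroup that reduces them. A tiny sketch of the stride fix (plain C++, assumed tile shape):

```cpp
// Each stream-K tile owns a whole MPerBlock x NPerBlock slab of the float
// workspace, so tile i starts i * tile_elems in, not i elements in.
#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t tile_elems = 32 * 64; // MPerBlock * NPerBlock (assumed)
    for(std::size_t i = 0; i < 3; ++i)
    {
        const std::size_t wrong = i;              // old offset: overlapping slabs
        const std::size_t right = i * tile_elems; // fixed offset: disjoint slabs
        std::printf("tile %zu: old offset %zu, fixed offset %zu\n", i, wrong, right);
    }
    return 0;
}
```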
@@ -685,12 +685,12 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                            dst_wave_buffer_resource,
                                            dst_thread_addr_offset,
                                            dst_wave_addr_offset,
-                                           0);
+                                           static_cast<index_t>(coherence));
            llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<1>{}],
                                                dst_wave_buffer_resource,
                                                dst_thread_addr_offset,
                                                dst_wave_addr_offset + 4 * sizeof(float),
-                                               0);
+                                               static_cast<index_t>(coherence));
        }
    }
    else if constexpr(is_same<T, half_t>::value)
...
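Note on the hunk above: the raw buffer-store intrinsics take a cache-policy ("aux") operand that was hard-coded to 0, i.e. default caching; the fix forwards the caller's `coherence` argument instead, which is what lets the stream-K partials bypass the non-coherent L1. A reduced sketch of the pattern (names simplified and assumed; the real code dispatches to the `llvm_amdgcn_raw_buffer_store_*` intrinsics):

```cpp
// Sketch: thread a compile-time cache-coherence choice down to where the
// store's aux operand is formed, instead of hard-coding 0.
#include <cstdio>

enum class coherence_bits : int
{
    none = 0, // default caching
    glc  = 1  // globally coherent: make the write visible to other CUs
};

template <coherence_bits coherence>
void buffer_store(float value, float* dst)
{
    const int aux = static_cast<int>(coherence); // was a literal 0 before the fix
    *dst = value; // stand-in for the raw buffer-store intrinsic
    std::printf("stored %f with aux=%d\n", value, aux);
}

int main()
{
    float x = 0.f;
    buffer_store<coherence_bits::glc>(1.f, &x);
    return 0;
}
```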
@@ -20,8 +20,9 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
  */
 struct DeviceMem
 {
-    DeviceMem() = delete;
+    DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
     DeviceMem(std::size_t mem_size);
+    void Realloc(std::size_t mem_size);
     void* GetDeviceBuffer() const;
     std::size_t GetBufferSize() const;
     void ToDevice(const void* p) const;
...
@@ -10,20 +10,49 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
     hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
 }
+void DeviceMem::Realloc(std::size_t mem_size)
+{
+    if(mpDeviceBuf)
+    {
+        hip_check_error(hipFree(mpDeviceBuf));
+    }
+    mMemSize = mem_size;
+    hip_check_error(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+}
 void* DeviceMem::GetDeviceBuffer() const { return mpDeviceBuf; }
 std::size_t DeviceMem::GetBufferSize() const { return mMemSize; }
 void DeviceMem::ToDevice(const void* p) const
 {
-    hip_check_error(hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
+    if(mpDeviceBuf)
+    {
+        hip_check_error(
+            hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
+    }
 }
 void DeviceMem::FromDevice(void* p) const
 {
-    hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
+    if(mpDeviceBuf)
+    {
+        hip_check_error(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
+    }
 }
-void DeviceMem::SetZero() const { hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize)); }
+void DeviceMem::SetZero() const
+{
+    if(mpDeviceBuf)
+    {
+        hip_check_error(hipMemset(mpDeviceBuf, 0, mMemSize));
+    }
+}
-DeviceMem::~DeviceMem() { hip_check_error(hipFree(mpDeviceBuf)); }
+DeviceMem::~DeviceMem()
+{
+    if(mpDeviceBuf)
+    {
+        hip_check_error(hipFree(mpDeviceBuf));
+    }
+}
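Note on the two hunks above: `DeviceMem` used to require a size at construction, which made optional buffers awkward; it now default-constructs empty, tolerates operations on a null buffer, and can be (re)allocated later via `Realloc`. A usage sketch of the new surface (the include path is an assumption and varies by tree; this needs the CK headers and HIP runtime to build):

```cpp
#include <cstddef>
#include "device_memory.hpp" // assumed include path for DeviceMem

void maybe_use_workspace(std::size_t required_bytes)
{
    DeviceMem workspace;   // new: starts null/zero-sized instead of being deleted
    workspace.SetZero();   // safe no-op while the buffer is null
    if(required_bytes != 0)
    {
        workspace.Realloc(required_bytes); // allocate on demand
        workspace.SetZero();               // now memsets the allocation
    }
} // destructor frees only if a buffer was actually allocated
```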
@@ -155,6 +155,13 @@ bool profile_gemm_streamk_impl(int do_verification,
                                                          b_element_op,
                                                          c_element_op,
                                                          NumSKBlocks);
+        DeviceMem workspace;
+        std::size_t workspace_size = op_ptr->GetWorkSpaceSize(argument_ptr);
+        if(workspace_size != 0)
+        {
+            workspace.Realloc(workspace_size);
+            op_ptr->SetWorkSpacePointer(argument_ptr, workspace.GetDeviceBuffer());
+        }
         auto invoker_ptr = op_ptr->MakeInvokerPointer();
...