Commit 4c5fe81e authored by Jing Zhang
Browse files

fixed examples; add async_mem_set

parent 7965d66a
...@@ -299,8 +299,8 @@ int main(int argc, char* argv[]) ...@@ -299,8 +299,8 @@ int main(int argc, char* argv[])
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
problem_size.Ms.push_back(256 + 256 * i); problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ns.push_back(128 + 128 * i); problem_size.Ns.push_back(256);
problem_size.Ks.push_back(128 + 64 * i); problem_size.Ks.push_back(128);
problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]);
......
...@@ -300,8 +300,8 @@ int main(int argc, char* argv[]) ...@@ -300,8 +300,8 @@ int main(int argc, char* argv[])
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
problem_size.Ms.push_back(256 + 256 * i); problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ns.push_back(128 + 128 * i); problem_size.Ns.push_back(256);
problem_size.Ks.push_back(128 + 64 * i); problem_size.Ks.push_back(128);
problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]);
......
...@@ -59,7 +59,9 @@ struct BaseOperator ...@@ -59,7 +59,9 @@ struct BaseOperator
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const virtual void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const
{ {
assert(p_arg); assert(p_arg);
p_arg->p_workspace_ = p_workspace; p_arg->p_workspace_ = p_workspace;
......
...@@ -817,12 +817,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -817,12 +817,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>); return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
} }
void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& stream_config = StreamConfig{}) const override
{ {
auto p_arg_ = dynamic_cast<Argument*>(p_arg); auto p_arg_ = dynamic_cast<Argument*>(p_arg);
p_arg_->p_workspace_ = p_workspace; p_arg_->p_workspace_ = p_workspace;
hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg))); hip_check_error(
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
} }
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment