"...git@developer.sourcefind.cn:OpenDAS/llama-factory.git" did not exist on "1bc2def59c5cb24e8cb306d262424e6a4a7e7338"
Unverified Commit f741895f authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Merge branch 'develop' into barkocot/lwpck-1063-dev

parents e5f05e71 c7d5c772
def rocmnode(name) { def rocmnode(name) {
return 'rocmtest && miopen && ' + name return '(rocmtest || miopen) && ' + name
} }
def show_node_info() { def show_node_info() {
......
...@@ -299,8 +299,8 @@ int main(int argc, char* argv[]) ...@@ -299,8 +299,8 @@ int main(int argc, char* argv[])
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
problem_size.Ms.push_back(256 + 256 * i); problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ns.push_back(128 + 128 * i); problem_size.Ns.push_back(256);
problem_size.Ks.push_back(128 + 64 * i); problem_size.Ks.push_back(128);
problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]);
......
...@@ -300,8 +300,8 @@ int main(int argc, char* argv[]) ...@@ -300,8 +300,8 @@ int main(int argc, char* argv[])
for(int i = 0; i < problem_size.group_count; i++) for(int i = 0; i < problem_size.group_count; i++)
{ {
problem_size.Ms.push_back(256 + 256 * i); problem_size.Ms.push_back(256 + 256 * i);
problem_size.Ns.push_back(128 + 128 * i); problem_size.Ns.push_back(256);
problem_size.Ks.push_back(128 + 64 * i); problem_size.Ks.push_back(128);
problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_As.push_back(problem_size.Ks[i]);
problem_size.stride_Bs.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]);
......
...@@ -59,7 +59,9 @@ struct BaseOperator ...@@ -59,7 +59,9 @@ struct BaseOperator
virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const virtual void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const
{ {
assert(p_arg); assert(p_arg);
p_arg->p_workspace_ = p_workspace; p_arg->p_workspace_ = p_workspace;
......
...@@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType, ...@@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
return (workspace_size); return (workspace_size);
}; };
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
...@@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
return (workspace_size); return (workspace_size);
}; };
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
...@@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
return (workspace_size); return (workspace_size);
}; };
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
...@@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
return (workspace_size); return (workspace_size);
}; };
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
...@@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
} }
} }
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
...@@ -817,12 +817,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -817,12 +817,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>); return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
} }
void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& stream_config = StreamConfig{}) const override
{ {
auto p_arg_ = dynamic_cast<Argument*>(p_arg); auto p_arg_ = dynamic_cast<Argument*>(p_arg);
p_arg_->p_workspace_ = p_workspace; p_arg_->p_workspace_ = p_workspace;
hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg))); hip_check_error(
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
} }
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
......
...@@ -577,7 +577,9 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp ...@@ -577,7 +577,9 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
return (workspace_size); return (workspace_size);
}; };
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment