"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "b9ae7ba5b7813bcc838e4e250653e2969b1ed812"
Commit 1efeccbd authored by Jing Zhang's avatar Jing Zhang
Browse files

add stream to all deviceOp using SetWorkspace

parent 4c5fe81e
......@@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
......@@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
......@@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
......@@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
......@@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
}
}
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
......@@ -577,7 +577,9 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
return (workspace_size);
};
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{
Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment