Commit fe6ee651 authored by fsx950223
Browse files

add workspace size

parent 3c6a9b06
...@@ -316,7 +316,7 @@ int run(int argc, char* argv[]) ...@@ -316,7 +316,7 @@ int run(int argc, char* argv[])
std::vector<DeviceMemPtr> ygrad_tensors_device; std::vector<DeviceMemPtr> ygrad_tensors_device;
std::vector<DeviceMemPtr> kgrad_tensors_device; std::vector<DeviceMemPtr> kgrad_tensors_device;
std::vector<DeviceMemPtr> vgrad_tensors_device; std::vector<DeviceMemPtr> vgrad_tensors_device;
std::size_t group_count = 1; std::size_t group_count = 3;
std::size_t flop = 0, num_byte = 0; std::size_t flop = 0, num_byte = 0;
for(std::size_t i=0; i<group_count; i++){ for(std::size_t i=0; i<group_count; i++){
// int M = 128 * (rand() % 8 + 1); // int M = 128 * (rand() % 8 + 1);
...@@ -538,6 +538,17 @@ int run(int argc, char* argv[]) ...@@ -538,6 +538,17 @@ int run(int argc, char* argv[])
Scale{alpha}, Scale{alpha},
QKVElementOp{}, QKVElementOp{},
YElementOp{}); YElementOp{});
DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
if(!gemm.IsSupportedArgument(argument))
{
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
......
...@@ -1062,6 +1062,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle ...@@ -1062,6 +1062,11 @@ struct DeviceGroupedMultiheadAttentionBackward_Xdl_CShuffle
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
/// Returns the number of bytes of workspace this operation needs for `p_arg`:
/// one GroupKernelArg per group (group_count_ * sizeof(GroupKernelArg)).
///
/// @param p_arg Type-erased argument; expected to actually be an Argument.
/// @return Required workspace size in bytes, or 0 when `p_arg` is null or is
///         not an Argument.
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
    const auto* const typed_arg = dynamic_cast<const Argument*>(p_arg);
    if(typed_arg == nullptr)
    {
        // A pointer dynamic_cast yields nullptr for a null or foreign argument;
        // report "no workspace" instead of dereferencing it (undefined behaviour).
        return 0;
    }
    return typed_arg->group_count_ * sizeof(GroupKernelArg);
}
static auto MakeArgument(const std::vector<const DataType*>& p_As, static auto MakeArgument(const std::vector<const DataType*>& p_As,
const std::vector<const DataType*>& p_Bs, const std::vector<const DataType*>& p_Bs,
const std::vector<const DataType*>& p_B1s, const std::vector<const DataType*>& p_B1s,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment