Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
c7a96ed5
Unverified
Commit
c7a96ed5
authored
Jun 17, 2022
by
ltqin
Committed by
GitHub
Jun 16, 2022
Browse files
add p_workspace to baseargument (#275)
parent
6eb55499
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
32 deletions
+31
-32
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+7
-1
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+24
-31
No files found.
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
c7a96ed5
...
@@ -15,6 +15,8 @@ struct BaseArgument
...
@@ -15,6 +15,8 @@ struct BaseArgument
BaseArgument
&
operator
=
(
const
BaseArgument
&
)
=
default
;
BaseArgument
&
operator
=
(
const
BaseArgument
&
)
=
default
;
virtual
~
BaseArgument
()
{}
virtual
~
BaseArgument
()
{}
void
*
p_workspace_
=
nullptr
;
};
};
struct
BaseInvoker
struct
BaseInvoker
...
@@ -42,7 +44,11 @@ struct BaseOperator
...
@@ -42,7 +44,11 @@ struct BaseOperator
virtual
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
)
const
{
return
0
;
}
virtual
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
)
const
{
return
0
;
}
virtual
void
SetWorkSpacePointer
(
BaseArgument
*
,
void
*
)
const
{}
virtual
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
*
p_workspace
)
const
final
{
assert
(
p_arg
);
p_arg
->
p_workspace_
=
p_workspace
;
}
virtual
~
BaseOperator
()
{}
virtual
~
BaseOperator
()
{}
};
};
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
c7a96ed5
...
@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdl
...
@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdl
{
{
grid_size_
=
0
;
grid_size_
=
0
;
gemm_descs_args
_workspace_
=
nullptr
;
p
_workspace_
=
nullptr
;
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_shapes
.
size
());
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_shapes
.
size
());
...
@@ -437,8 +437,6 @@ struct DeviceGroupedGemmXdl
...
@@ -437,8 +437,6 @@ struct DeviceGroupedGemmXdl
std
::
vector
<
GemmDescKernelArg
>
gemm_desc_kernel_arg_
;
std
::
vector
<
GemmDescKernelArg
>
gemm_desc_kernel_arg_
;
void
*
gemm_descs_args_workspace_
;
index_t
grid_size_
;
index_t
grid_size_
;
};
};
...
@@ -488,7 +486,7 @@ struct DeviceGroupedGemmXdl
...
@@ -488,7 +486,7 @@ struct DeviceGroupedGemmXdl
}
}
hipGetErrorString
(
hipGetErrorString
(
hipMemcpy
(
arg
.
gemm_descs_args
_workspace_
,
hipMemcpy
(
arg
.
p
_workspace_
,
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmDescKernelArg
),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmDescKernelArg
),
hipMemcpyHostToDevice
));
hipMemcpyHostToDevice
));
...
@@ -507,17 +505,17 @@ struct DeviceGroupedGemmXdl
...
@@ -507,17 +505,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation
,
CElementwiseOperation
,
true
>
;
true
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
stream_config
,
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
arg
.
grid_size_
),
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
cast_pointer_to_constant_address_space
(
arg
.
gemm_descs_args
_workspace_
),
cast_pointer_to_constant_address_space
(
arg
.
p
_workspace_
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
arg
.
c_element_op_
);
}
}
else
else
{
{
...
@@ -531,17 +529,17 @@ struct DeviceGroupedGemmXdl
...
@@ -531,17 +529,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation
,
CElementwiseOperation
,
false
>
;
false
>
;
ave_time
=
launch_and_time_kernel
(
ave_time
=
stream_config
,
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
arg
.
grid_size_
),
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
cast_pointer_to_constant_address_space
(
arg
.
gemm_descs_args
_workspace_
),
cast_pointer_to_constant_address_space
(
arg
.
p
_workspace_
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
arg
.
c_element_op_
);
}
}
return
ave_time
;
return
ave_time
;
...
@@ -635,11 +633,6 @@ struct DeviceGroupedGemmXdl
...
@@ -635,11 +633,6 @@ struct DeviceGroupedGemmXdl
{
{
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
GemmDescKernelArg
);
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
GemmDescKernelArg
);
}
}
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
*
workspace_ptr
)
const
override
{
dynamic_cast
<
Argument
*>
(
p_arg
)
->
gemm_descs_args_workspace_
=
workspace_ptr
;
}
};
};
}
// namespace device
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment