Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
51a549c9
"tests/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "1e2a2c681de94d45ce997fbd8c4a6b162c091bd8"
Commit
51a549c9
authored
May 25, 2022
by
Jing Zhang
Browse files
moved hipMemAlloc outside of deviceOp
parent
d8f1458f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
53 additions
and
13 deletions
+53
-13
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+13
-4
include/ck/tensor_operation/gpu/device/device_gemm.hpp
include/ck/tensor_operation/gpu/device/device_gemm.hpp
+1
-0
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+39
-9
No files found.
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
51a549c9
...
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
...
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
exit
(
0
);
exit
(
0
);
}
}
int
group_count
=
4
;
int
group_count
=
rand
()
%
16
+
1
;
// GEMM shape
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmShape
>
gemm_shapes
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmShape
>
gemm_shapes
;
...
@@ -189,11 +189,20 @@ int main(int argc, char* argv[])
...
@@ -189,11 +189,20 @@ int main(int argc, char* argv[])
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
c_element_op
=
CElementOp
{};
auto
c_element_op
=
CElementOp
{};
// do GEMM
auto
gemm
=
DeviceGemmInstance
{};
auto
gemm
=
DeviceGemmInstance
{};
auto
invoker
=
gemm
.
MakeInvoker
();
auto
invoker
=
gemm
.
MakeInvoker
();
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_c
,
gemm_shapes
,
a_element_op
,
b_element_op
,
c_element_op
);
DeviceMem
gemm_desc_workspace
(
gemm
.
GetWorkSpaceSize
(
gemm_shapes
.
size
()));
// do GEMM
auto
argument
=
gemm
.
MakeArgument
(
p_a
,
p_b
,
p_c
,
gemm_shapes
,
gemm_desc_workspace
.
GetDeviceBuffer
(),
a_element_op
,
b_element_op
,
c_element_op
);
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
if
(
!
gemm
.
IsSupportedArgument
(
argument
))
{
{
...
...
include/ck/tensor_operation/gpu/device/device_gemm.hpp
View file @
51a549c9
...
@@ -51,6 +51,7 @@ struct DeviceGroupedGemm : public BaseOperator
...
@@ -51,6 +51,7 @@ struct DeviceGroupedGemm : public BaseOperator
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
void
*
gemm_descs_args_workspace
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
51a549c9
...
@@ -350,6 +350,7 @@ struct DeviceGroupedGemmXdl
...
@@ -350,6 +350,7 @@ struct DeviceGroupedGemmXdl
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
void
*
gemm_descs_args_workspace
,
index_t
M01
,
index_t
M01
,
index_t
N01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
...
@@ -363,6 +364,8 @@ struct DeviceGroupedGemmXdl
...
@@ -363,6 +364,8 @@ struct DeviceGroupedGemmXdl
{
{
grid_size_
=
0
;
grid_size_
=
0
;
gemm_descs_args_workspace_
=
gemm_descs_args_workspace
;
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_shapes
.
size
());
group_count_
=
ck
::
type_convert
<
ck
::
index_t
>
(
gemm_shapes
.
size
());
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_a
.
size
())
&&
if
(
!
(
group_count_
==
ck
::
type_convert
<
ck
::
index_t
>
(
p_a
.
size
())
&&
...
@@ -437,6 +440,8 @@ struct DeviceGroupedGemmXdl
...
@@ -437,6 +440,8 @@ struct DeviceGroupedGemmXdl
std
::
vector
<
GemmDescKernelArg
>
gemm_desc_kernel_arg_
;
std
::
vector
<
GemmDescKernelArg
>
gemm_desc_kernel_arg_
;
void
*
gemm_descs_args_workspace_
;
index_t
grid_size_
;
index_t
grid_size_
;
};
};
...
@@ -485,12 +490,13 @@ struct DeviceGroupedGemmXdl
...
@@ -485,12 +490,13 @@ struct DeviceGroupedGemmXdl
}
}
}
}
void
*
gemm_descs_const_
;
// void* gemm_descs_args_workspace;
hipGetErrorString
(
hipMalloc
(
// hipGetErrorString(hipMalloc(
&
gemm_descs_const_
,
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmDescKernelArg
)));
// &gemm_descs_args_workspace, arg.gemm_desc_kernel_arg_.size() *
// sizeof(GemmDescKernelArg)));
hipGetErrorString
(
hipGetErrorString
(
hipMemcpy
(
gemm_descs_
const
_
,
hipMemcpy
(
arg
.
gemm_descs_
args_workspace
_
,
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmDescKernelArg
),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmDescKernelArg
),
hipMemcpyHostToDevice
));
hipMemcpyHostToDevice
));
...
@@ -515,7 +521,7 @@ struct DeviceGroupedGemmXdl
...
@@ -515,7 +521,7 @@ struct DeviceGroupedGemmXdl
dim3
(
arg
.
grid_size_
),
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
cast_pointer_to_constant_address_space
(
gemm_descs_
const
_
),
cast_pointer_to_constant_address_space
(
arg
.
gemm_descs_
args_workspace
_
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
...
@@ -539,7 +545,7 @@ struct DeviceGroupedGemmXdl
...
@@ -539,7 +545,7 @@ struct DeviceGroupedGemmXdl
dim3
(
arg
.
grid_size_
),
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
cast_pointer_to_constant_address_space
(
gemm_descs_
const
_
),
cast_pointer_to_constant_address_space
(
arg
.
gemm_descs_
args_workspace
_
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
...
@@ -581,11 +587,21 @@ struct DeviceGroupedGemmXdl
...
@@ -581,11 +587,21 @@ struct DeviceGroupedGemmXdl
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>
gemm_shapes
,
std
::
vector
<
GemmShape
>
gemm_shapes
,
void
*
gemm_descs_args_workspace
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
{
{
return
Argument
{
p_a
,
p_b
,
p_c
,
gemm_shapes
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
return
Argument
{
p_a
,
p_b
,
p_c
,
gemm_shapes
,
gemm_descs_args_workspace
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
@@ -595,13 +611,22 @@ struct DeviceGroupedGemmXdl
...
@@ -595,13 +611,22 @@ struct DeviceGroupedGemmXdl
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
const
void
*>&
p_b
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
void
*>&
p_c
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
std
::
vector
<
GemmShape
>&
gemm_shapes
,
void
*
gemm_descs_args_workspace
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
index_t
/* KBatch */
=
1
)
override
index_t
/* KBatch */
=
1
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_a
,
p_b
,
p_c
,
gemm_shapes
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
);
p_b
,
p_c
,
gemm_shapes
,
gemm_descs_args_workspace
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
);
}
}
// polymorphic
// polymorphic
...
@@ -632,6 +657,11 @@ struct DeviceGroupedGemmXdl
...
@@ -632,6 +657,11 @@ struct DeviceGroupedGemmXdl
return
str
.
str
();
return
str
.
str
();
}
}
static
size_t
GetWorkSpaceSize
(
const
index_t
group_count
)
{
return
group_count
*
sizeof
(
GemmDescKernelArg
);
}
};
};
}
// namespace device
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment