gaoqiong / composable_kernel

Commit 65d67fb7, authored Mar 11, 2022 by Jing Zhang
Parent: f9b740b5

    add ptr to GemmDesc

In short: the grouped-GEMM descriptor's element-offset fields (OffsetA/B/C) are replaced with per-group device pointers (p_a/p_b/p_c), ck::gemm_desc is renamed to ck::GemmShape, and the shared A/B/C base-pointer parameters are dropped from the grouped-GEMM device interfaces and kernel.

Showing 5 changed files with 124 additions and 148 deletions.
  example/14_grouped_gemm/grouped_gemm_xdl_fp16.cpp                            +30  -19
  include/ck/config.hpp                                                         +2   -2
  include/ck/tensor_operation/gpu/device/device_gemm.hpp                        +1   -4
  include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp           +86 -113
  include/ck/tensor_operation/gpu/grid/gridwise_grouped_gemm_xdlops_v2r3.hpp    +5  -10
example/14_grouped_gemm/grouped_gemm_xdl_fp16.cpp

@@ -79,21 +79,21 @@ int main(int argc, char* argv[])
     int group_count = 4;

     // GEMM shape
-    std::vector<ck::gemm_desc> gemm_shapes;
+    std::vector<ck::GemmShape> gemm_shapes;

     int A_size = 0, B_size = 0, C_size = 0;

     for(int i = 0; i < group_count; i++)
     {
-        int M = 3840;
-        int N = 1024;
-        int K = 4096;
+        int M = 256 + 256 * i;
+        int N = 128 + 128 * i;
+        int K = 64 + 64 * i;

-        gemm_shapes.push_back({M, N, K, K, K, N, A_size, B_size, C_size});
+        gemm_shapes.push_back({M, N, K, K, K, N, nullptr, nullptr, nullptr});

-        A_size += M * K;
-        B_size += N * K;
-        C_size += M * N;
+        A_size += gemm_shapes[i].M * gemm_shapes[i].K;
+        B_size += gemm_shapes[i].N * gemm_shapes[i].K;
+        C_size += gemm_shapes[i].M * gemm_shapes[i].N;
     }

     auto f_host_tensor_descriptor =

@@ -163,12 +163,27 @@ int main(int argc, char* argv[])
     std::vector<ADataType> a_tensors_data, b_tensors_data, c_tensors_data;

+    A_size = 0;
+    B_size = 0;
+    C_size = 0;
+
     for(int i = 0; i < gemm_shapes.size(); i++)
     {
         a_tensors_data.insert(
             a_tensors_data.end(), a_tensors[i].mData.begin(), a_tensors[i].mData.end());
         b_tensors_data.insert(
             b_tensors_data.end(), b_tensors[i].mData.begin(), b_tensors[i].mData.end());
+
+        gemm_shapes[i].p_a =
+            static_cast<ADataType*>(a_tensors_device_buf.GetDeviceBuffer()) + A_size;
+        gemm_shapes[i].p_b =
+            static_cast<BDataType*>(b_tensors_device_buf.GetDeviceBuffer()) + B_size;
+        gemm_shapes[i].p_c =
+            static_cast<CDataType*>(c_tensors_device_buf.GetDeviceBuffer()) + C_size;
+
+        A_size += gemm_shapes[i].M * gemm_shapes[i].K;
+        B_size += gemm_shapes[i].N * gemm_shapes[i].K;
+        C_size += gemm_shapes[i].M * gemm_shapes[i].N;
     }

     a_tensors_device_buf.ToDevice(a_tensors_data.data());

@@ -179,16 +194,9 @@ int main(int argc, char* argv[])
     auto c_element_op = CElementOp{};

     // do GEMM
-    auto gemm     = DeviceGemmInstance{};
-    auto invoker  = gemm.MakeInvoker();
-    auto argument = gemm.MakeArgument(
-        static_cast<ADataType*>(a_tensors_device_buf.GetDeviceBuffer()),
-        static_cast<BDataType*>(b_tensors_device_buf.GetDeviceBuffer()),
-        static_cast<CDataType*>(c_tensors_device_buf.GetDeviceBuffer()),
-        gemm_shapes,
-        a_element_op,
-        b_element_op,
-        c_element_op);
+    auto gemm     = DeviceGemmInstance{};
+    auto invoker  = gemm.MakeInvoker();
+    auto argument = gemm.MakeArgument(gemm_shapes, a_element_op, b_element_op, c_element_op);

     if(!gemm.IsSupportedArgument(argument))
     {

@@ -210,11 +218,14 @@ int main(int argc, char* argv[])
     c_tensors_device_buf.FromDevice(c_tensors_data.data());

+    C_size = 0;
+
     for(int i = 0; i < gemm_shapes.size(); i++)
     {
         memcpy(c_device_tensors[i].mData.data(),
-               c_tensors_data.data() + gemm_shapes[i].OffsetC,
+               c_tensors_data.data() + C_size,
               c_device_tensors[i].mData.size() * sizeof(CDataType));
+
+        C_size += gemm_shapes[i].M * gemm_shapes[i].N;
     }

     if(do_verification)
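Note: the example keeps the same packed-buffer layout as before, but now records it as per-group pointers instead of offsets: group i's A block begins M*K elements after group i-1's, and likewise N*K for B and M*N for C. Below is a minimal standalone sketch of that bookkeeping using this commit's shape progression; it prints element offsets only, and the example's ADataType and GetDeviceBuffer() calls are not reproduced here.

#include <cstdio>

// Standalone sketch of the offset bookkeeping in the example above.
// Element counts only; the real code adds these to GetDeviceBuffer() pointers.
int main()
{
    const int group_count = 4;
    long A_size = 0, B_size = 0, C_size = 0;

    for(int i = 0; i < group_count; i++)
    {
        const int M = 256 + 256 * i;
        const int N = 128 + 128 * i;
        const int K = 64 + 64 * i;

        // p_a/p_b/p_c for group i would point A_size/B_size/C_size elements
        // past the start of the packed A/B/C device buffers.
        std::printf("group %d: a_off=%ld b_off=%ld c_off=%ld\n", i, A_size, B_size, C_size);

        A_size += static_cast<long>(M) * K;
        B_size += static_cast<long>(N) * K;
        C_size += static_cast<long>(M) * N;
    }
    return 0;
}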
include/ck/config.hpp

@@ -177,11 +177,11 @@ enum ActivTypeEnum_t
 using index_t      = int32_t;
 using long_index_t = int64_t;

-struct gemm_desc
+struct GemmShape
 {
     ck::index_t M, N, K;
     ck::index_t StrideA, StrideB, StrideC;
-    ck::index_t OffsetA, OffsetB, OffsetC;
+    void *p_a, *p_b, *p_c;
 };

 } // namespace ck
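Note: with the rename, the descriptor carries raw per-group device pointers rather than element offsets into shared base buffers. A minimal sketch of how a caller might fill it, with stand-ins assumed for ck::index_t and the data type (the buffer names here are hypothetical; the real example adds a running offset to GetDeviceBuffer()):

#include <cstdint>
#include <vector>

using index_t = std::int32_t; // stand-in for ck::index_t

struct GemmShape // mirrors the struct above
{
    index_t M, N, K;
    index_t StrideA, StrideB, StrideC;
    void *p_a, *p_b, *p_c; // per-group device pointers (were OffsetA/B/C)
};

int main()
{
    using DataType = float; // the example uses fp16; float keeps the sketch standalone

    // Pretend these are device allocations; plain host vectors here.
    std::vector<DataType> a_buf(1024 * 64), b_buf(512 * 64), c_buf(1024 * 512);

    // One group occupying the start of each packed buffer.
    GemmShape shape{1024, 512, 64, /*StrideA*/ 64, /*StrideB*/ 64, /*StrideC*/ 512,
                    nullptr, nullptr, nullptr};
    shape.p_a = a_buf.data(); // in the real example: GetDeviceBuffer() + running offset
    shape.p_b = b_buf.data();
    shape.p_c = c_buf.data();
    return 0;
}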
include/ck/tensor_operation/gpu/device/device_gemm.hpp

@@ -64,10 +64,7 @@ template <typename AElementwiseOperation,
           typename CElementwiseOperation>
 struct DeviceGroupedGemm : public BaseOperator
 {
-    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                              const void* p_b,
-                                                              void* p_c,
-                                                              std::vector<gemm_desc> gemm_shapes,
+    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
                                                               AElementwiseOperation a_element_op,
                                                               BElementwiseOperation b_element_op,
                                                               CElementwiseOperation c_element_op,
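Note: the effect on the abstract interface is that operand pointers no longer appear as separate MakeArgumentPointer parameters; each GemmShape entry carries its own. A stripped-down skeleton of the new interface shape, with the element-op parameters omitted and all types reduced to simplified stand-ins rather than the real ck classes:

#include <memory>
#include <vector>

struct BaseArgument
{
    virtual ~BaseArgument() = default;
};

struct GemmShape
{
    int M, N, K, StrideA, StrideB, StrideC;
    void *p_a, *p_b, *p_c;
};

struct DeviceGroupedGemm
{
    virtual ~DeviceGroupedGemm() = default;

    // Before this commit the signature also took const void* p_a, const void* p_b,
    // and void* p_c; now the per-group pointers travel inside gemm_shapes.
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<GemmShape> gemm_shapes) = 0;
};

int main() { return 0; } // interface sketch only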
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp

@@ -233,99 +233,84 @@ struct DeviceGroupedGemmXdl
             c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
-        ck::index_t OffsetA, OffsetB, OffsetC;
+        const ADataType* a_ptr;
+        const BDataType* b_ptr;
+        CDataType* c_ptr;
         ck::index_t BlockStart, BlockEnd;
     };

     // Argument
     struct Argument : public BaseArgument
     {
-        Argument(const ADataType* p_a_grid,
-                 const BDataType* p_b_grid,
-                 CDataType* p_c_grid,
-                 std::vector<gemm_desc> gemm_shapes,
+        Argument(std::vector<GemmShape> gemm_shapes,
                  index_t M01,
                  index_t N01,
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op)
-            : p_a_grid_{p_a_grid},
-              p_b_grid_{p_b_grid},
-              p_c_grid_{p_c_grid},
-              gemm_shapes_{gemm_shapes},
-              M01_{M01},
+            : M01_{M01},
              N01_{N01},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
         {
             grid_size = 0;

-            static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-                if(i < gemm_shapes_.size())
-                {
-                    const index_t M = gemm_shapes_[i].M;
-                    const index_t N = gemm_shapes_[i].N;
-                    const index_t K = gemm_shapes_[i].K;
-
-                    const index_t StrideA = gemm_shapes_[i].StrideA;
-                    const index_t StrideB = gemm_shapes_[i].StrideB;
-                    const index_t StrideC = gemm_shapes_[i].StrideC;
-
-                    gemm_desc_(i).a_grid_desc_k0_m_k1_ =
-                        DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
-                    gemm_desc_(i).b_grid_desc_k0_n_k1_ =
-                        DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
-                    gemm_desc_(i).c_grid_desc_m_n_ =
-                        DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
-
-                    const index_t grid_size_grp =
-                        GridwiseGemm::CalculateGridSize(gemm_desc_[i].c_grid_desc_m_n_);
-
-                    gemm_desc_(i).BlockStart = grid_size;
-                    gemm_desc_(i).BlockEnd   = grid_size + grid_size_grp;
-
-                    grid_size += grid_size_grp;
-
-                    gemm_desc_(i).OffsetA = gemm_shapes_[i].OffsetA;
-                    gemm_desc_(i).OffsetB = gemm_shapes_[i].OffsetB;
-                    gemm_desc_(i).OffsetC = gemm_shapes_[i].OffsetC;
-
-                    if(GridwiseGemm::CheckValidity(gemm_desc_[i].a_grid_desc_k0_m_k1_,
-                                                   gemm_desc_[i].b_grid_desc_k0_n_k1_,
-                                                   gemm_desc_[i].c_grid_desc_m_n_,
-                                                   M01_,
-                                                   N01_))
-                    {
-                        gemm_desc_(i).c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
-                            GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
-                                gemm_desc_[i].c_grid_desc_m_n_);
-
-                        gemm_desc_(i).block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
-                            gemm_desc_[i].c_grid_desc_m_n_, M01, N01);
-                    }
-                    else
-                    {
-                        gemm_desc_(i).BlockStart = -1;
-                        gemm_desc_(i).BlockEnd   = -1;
-                    }
-                }
-            });
+            for(index_t i = 0; i < gemm_shapes.size(); i++)
+            {
+                const index_t M = gemm_shapes[i].M;
+                const index_t N = gemm_shapes[i].N;
+                const index_t K = gemm_shapes[i].K;
+
+                const index_t StrideA = gemm_shapes[i].StrideA;
+                const index_t StrideB = gemm_shapes[i].StrideB;
+                const index_t StrideC = gemm_shapes[i].StrideC;
+
+                const auto a_grid_desc_k0_m_k1_ =
+                    DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
+                const auto b_grid_desc_k0_n_k1_ =
+                    DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
+                const auto c_grid_desc_m_n_ =
+                    DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
+
+                const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_);
+
+                const index_t BlockStart = grid_size;
+                const index_t BlockEnd   = grid_size + grid_size_grp;
+
+                grid_size += grid_size_grp;
+
+                if(GridwiseGemm::CheckValidity(
+                       a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
+                {
+                    const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
+                        GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
+
+                    const auto block_2_ctile_map_ =
+                        GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
+
+                    GemmShape_.push_back(GemmDesc{a_grid_desc_k0_m_k1_,
+                                                  b_grid_desc_k0_n_k1_,
+                                                  c_grid_desc_m_n_,
+                                                  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
+                                                  block_2_ctile_map_,
+                                                  static_cast<const ADataType*>(gemm_shapes[i].p_a),
+                                                  static_cast<const BDataType*>(gemm_shapes[i].p_b),
+                                                  static_cast<CDataType*>(gemm_shapes[i].p_c),
+                                                  BlockStart,
+                                                  BlockEnd});
+                }
+            }
         }

         //  private:
-        const ADataType* p_a_grid_;
-        const BDataType* p_b_grid_;
-        CDataType* p_c_grid_;
-        std::vector<gemm_desc> gemm_shapes_;
         index_t M01_;
         index_t N01_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
-        StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_;
+        std::vector<GemmDesc> GemmShape_;
         index_t grid_size;
     };
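Note: the constructor's job is unchanged in spirit: walk the groups once, size each group's grid, and hand it a contiguous [BlockStart, BlockEnd) slice of workgroup ids. A host-side model of that partitioning is sketched below; the tile sizes are illustrative placeholders rather than the instance's real MPerBlock/NPerBlock, and GridwiseGemm::CalculateGridSize is modeled as a ceiling-division tile count.

#include <cstdio>
#include <utility>
#include <vector>

int main()
{
    // Placeholder tile sizes; a real DeviceGroupedGemmXdl instance fixes these at
    // compile time and computes the count via GridwiseGemm::CalculateGridSize.
    const int MPerBlock = 256, NPerBlock = 128;

    // (M, N) per group, matching the example's shape progression.
    const std::vector<std::pair<int, int>> shapes{{256, 128}, {512, 256}, {768, 384}, {1024, 512}};

    int grid_size = 0;
    for(const auto& mn : shapes)
    {
        const int grid_size_grp = ((mn.first + MPerBlock - 1) / MPerBlock) *
                                  ((mn.second + NPerBlock - 1) / NPerBlock);
        const int BlockStart = grid_size;
        const int BlockEnd   = grid_size + grid_size_grp;
        grid_size += grid_size_grp;

        std::printf("blocks [%d, %d)\n", BlockStart, BlockEnd);
    }
    std::printf("total grid_size = %d\n", grid_size); // prints 30 for these shapes
    return 0;
}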
@@ -337,44 +322,51 @@ struct DeviceGroupedGemmXdl
     float Run(const Argument& arg, int nrepeat = 1)
     {
+        StaticallyIndexedArray<GemmDesc, MaxGroupCount> GemmShape_arg;
+
+        bool has_main_k0_block_loop = true;
+
         static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-            if(i < arg.gemm_shapes_.size())
+            if(i < arg.GemmShape_.size())
             {
+                GemmShape_arg(i) = arg.GemmShape_[i];
+
                 std::cout << "arg.a_grid_desc_k0_m_k1_{"
-                          << arg.gemm_desc_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
-                          << arg.gemm_desc_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-                          << arg.gemm_desc_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"
+                          << GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
+                          << GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+                          << GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"
                           << std::endl;

                 std::cout << "arg.b_grid_desc_k0_n_k1_{"
-                          << arg.gemm_desc_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
-                          << arg.gemm_desc_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-                          << arg.gemm_desc_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"
+                          << GemmShape_arg[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
+                          << GemmShape_arg[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+                          << GemmShape_arg[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"
                           << std::endl;

                 std::cout << "arg.c_grid_desc_m_n_{ "
-                          << arg.gemm_desc_[i].c_grid_desc_m_n_.GetLength(I0) << ", "
-                          << arg.gemm_desc_[i].c_grid_desc_m_n_.GetLength(I1) << "}"
+                          << GemmShape_arg[i].c_grid_desc_m_n_.GetLength(I0) << ", "
+                          << GemmShape_arg[i].c_grid_desc_m_n_.GetLength(I1) << "}"
                           << std::endl;

-                std::cout << "Block: " << arg.gemm_desc_[i].BlockStart << ", "
-                          << arg.gemm_desc_[i].BlockEnd << std::endl;
-
-                if(!GridwiseGemm::CheckValidity(arg.gemm_desc_[i].a_grid_desc_k0_m_k1_,
-                                                arg.gemm_desc_[i].b_grid_desc_k0_n_k1_,
-                                                arg.gemm_desc_[i].c_grid_desc_m_n_,
+                if(!GridwiseGemm::CheckValidity(GemmShape_arg[i].a_grid_desc_k0_m_k1_,
+                                                GemmShape_arg[i].b_grid_desc_k0_n_k1_,
+                                                GemmShape_arg[i].c_grid_desc_m_n_,
                                                 arg.M01_,
                                                 arg.N01_))
                 {
                     throw std::runtime_error(
                         "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
                 }
+
+                const auto K0 = GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0);
+
+                if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop)
+                {
+                    throw std::runtime_error(
+                        "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
+                }
             }
         });

-        const auto K0 = arg.gemm_desc_[I0].a_grid_desc_k0_m_k1_.GetLength(I0);
-
-        const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
-
         float ave_time = 0;

@@ -396,11 +388,8 @@ struct DeviceGroupedGemmXdl
                                                   dim3(arg.grid_size),
                                                   dim3(BlockSize),
                                                   0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.gemm_desc_,
-                                                  arg.gemm_shapes_.size(),
+                                                  GemmShape_arg,
+                                                  arg.GemmShape_.size(),
                                                   arg.a_element_op_,
                                                   arg.b_element_op_,
                                                   arg.c_element_op_);

@@ -423,11 +412,8 @@ struct DeviceGroupedGemmXdl
                                                   dim3(arg.grid_size),
                                                   dim3(BlockSize),
                                                   0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.gemm_desc_,
-                                                  arg.gemm_shapes_.size(),
+                                                  GemmShape_arg,
+                                                  arg.GemmShape_.size(),
                                                   arg.a_element_op_,
                                                   arg.b_element_op_,
                                                   arg.c_element_op_);

@@ -451,9 +437,9 @@ struct DeviceGroupedGemmXdl
     static bool IsSupportedArgument(const Argument& arg)
     {
-        return GridwiseGemm::CheckValidity(arg.gemm_desc_[Number<0>{}].a_grid_desc_k0_m_k1_,
-                                           arg.gemm_desc_[Number<0>{}].b_grid_desc_k0_n_k1_,
-                                           arg.gemm_desc_[Number<0>{}].c_grid_desc_m_n_,
+        return GridwiseGemm::CheckValidity(arg.GemmShape_[0].a_grid_desc_k0_m_k1_,
+                                           arg.GemmShape_[0].b_grid_desc_k0_n_k1_,
+                                           arg.GemmShape_[0].c_grid_desc_m_n_,
                                            arg.M01_,
                                            arg.N01_);
     }

@@ -464,38 +450,25 @@ struct DeviceGroupedGemmXdl
         return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
     }

-    static auto MakeArgument(const ADataType* p_a,
-                             const BDataType* p_b,
-                             CDataType* p_c,
-                             std::vector<gemm_desc> gemm_shapes,
+    static auto MakeArgument(std::vector<GemmShape> gemm_shapes,
                              AElementwiseOperation a_element_op,
                              BElementwiseOperation b_element_op,
                              CElementwiseOperation c_element_op)
     {
-        return Argument{
-            p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
+        return Argument{gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
     }

     static auto MakeInvoker() { return Invoker{}; }

     // polymorphic
-    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                      const void* p_b,
-                                                      void* p_c,
-                                                      std::vector<gemm_desc> gemm_shapes,
+    std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
                                                       AElementwiseOperation a_element_op,
                                                       BElementwiseOperation b_element_op,
                                                       CElementwiseOperation c_element_op,
                                                       index_t /* KBatch */ = 1) override
     {
-        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
-                                          static_cast<const BDataType*>(p_b),
-                                          static_cast<CDataType*>(p_c),
-                                          gemm_shapes,
-                                          1,
-                                          1,
-                                          a_element_op,
-                                          b_element_op,
-                                          c_element_op);
+        return std::make_unique<Argument>(
+            gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op);
     }

     // polymorphic
include/ck/tensor_operation/gpu/grid/gridwise_grouped_gemm_xdlops_v2r3.hpp

@@ -26,9 +26,6 @@ __global__ void
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_grouped_gemm_xdlops_v2r3(
-            const FloatAB* __restrict__ p_a_grid,
-            const FloatAB* __restrict__ p_b_grid,
-            FloatC* __restrict__ p_c_grid,
             const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
             const index_t group_count,
             const AElementwiseOperation a_element_op,

@@ -41,18 +38,16 @@ __global__ void
 #if 1
     static_for<0, MaxGroupCount, 1>{}([&](auto i) {
-        if(block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd)
+        if(block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
+           i < group_count)
         {
             auto group_id = i;

             const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart;

-            const index_t a_offset_grp = gemm_desc_[group_id].OffsetA;
-            const index_t b_offset_grp = gemm_desc_[group_id].OffsetB;
-            const index_t c_offset_grp = gemm_desc_[group_id].OffsetC;
-
             GridwiseGemm::template Run<HasMainK0BlockLoop>(
-                p_a_grid + a_offset_grp,
-                p_b_grid + b_offset_grp,
-                p_c_grid + c_offset_grp,
+                gemm_desc_[group_id].a_ptr,
+                gemm_desc_[group_id].b_ptr,
+                gemm_desc_[group_id].c_ptr,
                 p_shared,
                 gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
                 gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
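Note: on the device, each workgroup now resolves its group by testing block_id against the [BlockStart, BlockEnd) ranges and then reads that group's own a_ptr/b_ptr/c_ptr, rather than offsetting shared base pointers. A plain C++ model of the same lookup (GroupRange is a hypothetical stand-in for GemmDesc's range fields; the ranges match the partitioning sketch above):

#include <cassert>

struct GroupRange
{
    int BlockStart, BlockEnd; // half-open range of workgroup ids owned by a group
};

// Linear search mirroring the static_for in the kernel: the first range
// containing block_id wins; -1 means the block is not assigned to any group.
int find_group(int block_id, const GroupRange* descs, int group_count)
{
    for(int i = 0; i < group_count; i++)
    {
        if(block_id >= descs[i].BlockStart && block_id < descs[i].BlockEnd)
            return i;
    }
    return -1;
}

int main()
{
    const GroupRange ranges[] = {{0, 1}, {1, 5}, {5, 14}, {14, 30}};
    assert(find_group(0, ranges, 4) == 0);
    assert(find_group(7, ranges, 4) == 2);
    assert(find_group(30, ranges, 4) == -1);
    return 0;
}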