Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
94e03cf5
"sims/vscode:/vscode.git/clone" did not exist on "475b21b4f76924ef812f90911dece401bd20b502"
Commit
94e03cf5
authored
May 24, 2023
by
Po-Yen, Chen
Browse files
Remove unnecessary kernel arguments
parent
d23a7617
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
151 additions
and
179 deletions
+151
-179
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
...pu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
+99
-134
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+52
-45
No files found.
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
View file @
94e03cf5
...
...
@@ -20,6 +20,50 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
namespace
detail
{
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
c_grid_desc_m_n
,
index_t
NumKBlockLoop
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
,
NumKBlockLoop
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
c_grid_desc_m_n
;
ignore
=
NumKBlockLoop
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
// namespace detail
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
...
...
@@ -967,12 +1011,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
{
0
,
0
,
0
});
}
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I0
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I1
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ABCGridDescs
{}[
I2
])
>
;
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<
BlockSize
,
...
...
@@ -980,9 +1018,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
,
...
...
@@ -1014,6 +1049,15 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
7
,
// CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector
>
;
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
using
AGridDesc_K0_M_K1
=
remove_cvref_t
<
decltype
(
std
::
declval
<
ABCGridDescs
>
()[
I0
])
>
;
using
BGridDesc_K0_N_K1
=
remove_cvref_t
<
decltype
(
std
::
declval
<
ABCGridDescs
>
()[
I1
])
>
;
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
std
::
declval
<
ABCGridDescs
>
()[
I2
])
>
;
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
std
::
declval
<
CGridDesc_M_N
>
()));
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -1030,19 +1074,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
ck
::
index_t
M01
,
ck
::
index_t
N01
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
ck
::
index_t
M01
)
:
p_a_grid_
{
p_out_grid
},
p_b_grid_
{
p_wei_grid
},
p_c_grid_
{
p_in_grid
},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
out_element_op
},
b_element_op_
{
wei_element_op
},
c_element_op_
{
in_element_op
},
Conv_N_
{
N
},
Conv_K_
{
K
},
Conv_C_
{
C
},
...
...
@@ -1092,17 +1128,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
auto
block_2_ctile_map
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
],
M01_
,
N01_
);
if
(
GridwiseGemm
::
CheckValidity
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
block_2_ctile_map
))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
.
push_back
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
block_2_ctile_map
);
}
}
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
...
...
@@ -1150,18 +1175,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
auto
block_2_ctile_map
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
],
M01_
,
N01_
);
if
(
GridwiseGemm
::
CheckValidity
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
block_2_ctile_map
))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
.
push_back
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
block_2_ctile_map
);
}
}
}
}
...
...
@@ -1218,19 +1231,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
auto
block_2_ctile_map
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
descs
[
I2
],
M01_
,
N01_
);
if
(
GridwiseGemm
::
CheckValidity
(
descs
[
I0
],
descs
[
I1
],
descs
[
I2
],
block_2_ctile_map
))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
.
push_back
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
descs
[
I2
]));
block_2_ctile_map_container_
.
push_back
(
block_2_ctile_map
);
}
}
}
}
...
...
@@ -1242,14 +1242,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std
::
vector
<
AGridDesc_K0_M_K1
>
a_grid_desc_k0_m_k1_container_
;
std
::
vector
<
BGridDesc_K0_N_K1
>
b_grid_desc_k0_n_k1_container_
;
std
::
vector
<
CGridDesc_M_N
>
c_grid_desc_m_n_container_
;
std
::
vector
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
;
std
::
vector
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
block_2_ctile_map_container_
;
index_t
M01_
;
index_t
N01_
;
OutElementwiseOperation
a_element_op_
;
WeiElementwiseOperation
b_element_op_
;
InElementwiseOperation
c_element_op_
;
// for checking IsSupportedArgument()
index_t
Conv_N_
;
index_t
Conv_K_
;
...
...
@@ -1315,39 +1308,34 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
arg
.
c_grid_desc_m_n_container_
[
i
],
arg
.
block_2_ctile_map_container_
[
i
]))
arg
.
c_grid_desc_m_n_container_
[
i
]))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_container_
[
i
].
CalculateGridSize
(
arg
.
c_grid_desc_m_n_container_
[
i
]);
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I0
),
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I1
));
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I2
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
const
auto
kernel
=
detail
::
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
OutElementwiseOperation
,
WeiElementwiseOperation
,
InElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
true
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
g
rid_size
),
dim3
(
g
dx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
...
...
@@ -1355,32 +1343,24 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
],
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_container_
[
i
]);
arg
.
c_grid_desc_m_n_container_
[
i
],
GridwiseGemm
::
CalculateNumKBlockLoop
(
K
));
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
const
auto
kernel
=
detail
::
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceOp
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
OutElementwiseOperation
,
WeiElementwiseOperation
,
InElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
false
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
g
rid_size
),
dim3
(
g
dx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
...
...
@@ -1388,11 +1368,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
],
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_container_
[
i
]);
arg
.
c_grid_desc_m_n_container_
[
i
],
GridwiseGemm
::
CalculateNumKBlockLoop
(
K
));
}
}
return
ave_time
;
...
...
@@ -1446,8 +1423,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
arg
.
c_grid_desc_m_n_container_
[
i
],
arg
.
block_2_ctile_map_container_
[
i
]))
arg
.
c_grid_desc_m_n_container_
[
i
]))
{
return
false
;
}
...
...
@@ -1472,10 +1448,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
std
::
vector
<
ck
::
index_t
>
input_right_pads
)
{
return
Argument
{
p_in_grid
,
p_wei_grid
,
...
...
@@ -1490,11 +1463,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
1
,
1
,
in_element_op
,
wei_element_op
,
out_element_op
};
1
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -1513,9 +1482,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
override
InElementwiseOperation
,
WeiElementwiseOperation
,
OutElementwiseOperation
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
InDataType
*>
(
p_in_grid
),
static_cast
<
const
WeiDataType
*>
(
p_wei_grid
),
...
...
@@ -1530,11 +1499,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
1
,
1
,
in_element_op
,
wei_element_op
,
out_element_op
);
1
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
94e03cf5
...
...
@@ -17,7 +17,7 @@
namespace
ck
{
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
bool
HasMainKBlockLoop
>
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
>
#ifdef USE_WAVES_PER_EU
__attribute__
((
amdgpu_waves_per_eu
(
1
,
1
)))
#endif
...
...
@@ -45,7 +45,7 @@ __global__ void
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
,
karg
);
karg
.
NumKBlockLoop
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -128,21 +128,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
// Argument
struct
Argument
:
public
tensor_operation
::
device
::
BaseArgument
struct
Problem
{
__host__
Argument
(
const
FloatAB
*
p_a_grid_
,
const
FloatAB
*
p_b_grid_
,
FloatC
*
p_c_grid_
,
index_t
M_
,
__host__
Problem
(
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
)
:
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_c_grid
{
p_c_grid_
},
M
{
M_
},
:
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
...
...
@@ -157,7 +151,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__
void
Print
()
const
{
std
::
cout
<<
"
arg
{"
std
::
cout
<<
"
problem
{"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
...
...
@@ -170,9 +164,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
<<
"NumKBlockLoop: "
<<
NumKBlockLoop
<<
"}"
<<
std
::
endl
;
}
const
FloatAB
*
p_a_grid
;
const
FloatAB
*
p_b_grid
;
FloatC
*
p_c_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
...
...
@@ -185,6 +176,30 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
index_t
NumKBlockLoop
;
};
// Argument
struct
Argument
:
public
Problem
,
public
tensor_operation
::
device
::
BaseArgument
{
__host__
Argument
(
const
FloatAB
*
p_a_grid_
,
const
FloatAB
*
p_b_grid_
,
FloatC
*
p_c_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
)
:
Problem
{
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
},
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_c_grid
{
p_c_grid_
}
{
}
const
FloatAB
*
p_a_grid
;
const
FloatAB
*
p_b_grid
;
FloatC
*
p_c_grid
;
};
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
...
...
@@ -260,15 +275,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
(
a_block_space_size_aligned
+
b_block_space_size_aligned
)
*
sizeof
(
FloatAB
);
}
template
<
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
typename
Block2CTileMap
>
template
<
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
a_grid_desc_k0_m_k1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_k0_n_k1
.
GetLength
(
I1
);
...
...
@@ -290,17 +301,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
false
;
}
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
__host__
static
constexpr
bool
CheckValidity
(
const
Problem
&
problem
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
...
...
@@ -310,7 +316,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
"Invalid tuning param!"
);
// check gridwise gemm pipeline
const
index_t
K0
=
karg
.
K
/
K1Value
;
const
index_t
K0
=
problem
.
K
/
K1Value
;
const
auto
num_k_loop
=
K0
/
K0PerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
...
...
@@ -394,7 +400,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Argument
&
karg
)
index_t
NumKBlockLoop
)
{
#if ENABLE_DUMP_CLOCK
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -417,7 +423,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
BElementwiseOperation
b_element_op
{};
const
CElementwiseOperation
c_element_op
{};
const
auto
block_2_ctile_map
=
Block2CTileMap
{
karg
.
M
,
karg
.
N
};
const
auto
block_2_ctile_map
=
Block2CTileMap
{
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
)};
// divide block work by [M, N]
const
auto
block_work_idx
=
...
...
@@ -546,7 +553,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
K0PerBlock
,
0
,
0
);
// gridwise GEMM pipeline
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
karg
.
NumKBlockLoop
);
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
NumKBlockLoop
);
#if ENABLE_DUMP_CLOCK
long
loop_start
=
0
,
loop_end
=
0
;
...
...
@@ -790,8 +797,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
LoopSched
,
PipelineVer
>
;
using
typename
Parent
::
Argument
;
using
typename
Parent
::
GridwiseGemmPipe
;
using
typename
Parent
::
Problem
;
using
Parent
::
I1
;
...
...
@@ -899,7 +906,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
__host__
static
constexpr
bool
CheckValidity
(
const
Problem
&
problem
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
...
...
@@ -913,7 +920,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
if
(
!
(
problem
.
M
%
MPerBlock
==
0
))
{
return
false
;
}
...
...
@@ -924,7 +931,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
if
(
!
(
problem
.
N
%
NPerBlock
==
0
))
{
return
false
;
}
...
...
@@ -932,14 +939,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
problem
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
if
(
problem
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
...
...
@@ -947,21 +954,21 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
problem
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
if
(
problem
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
// check gridwise gemm pipeline
const
index_t
K0
=
karg
.
K
/
K1
;
const
index_t
K0
=
problem
.
K
/
K1
;
const
auto
num_k_loop
=
K0
/
K0PerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment