gaoqiong / composable_kernel / Commits

Commit bee4e344 authored May 19, 2023 by aska-0096

(5/5) attention pass, todo: debug lds perf bug

Parent: fd4ff3a7
Showing 7 changed files with 276 additions and 248 deletions.
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc  (+2 −2)
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp  (+48 −49)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  (+36 −25)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp  (+18 −19)
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp  (+15 −15)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp  (+130 −117)
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp  (+27 −21)
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
@@ -53,8 +53,8 @@ using DeviceConvFwdInstance =
        GemmSpec,   // GemmSpecialization
        1,          // Prefetch stage
        128,        // BlockSize
        64,         // MPerBlock
        64,         // NPerBlock
        64,         // KPerBlock
        4,          // K1
        16,         // MPerWMMA
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
@@ -305,7 +305,7 @@ struct BlockwiseGemmWMMA
        static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                // read A
                a_thread_copy_.Run(
                    a_block_desc_k0_m0_m1_m2_k1,
                    make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
...
@@ -365,58 +365,57 @@ struct BlockwiseGemmWMMA
        static_for<0, NRepeat, 1>{}([&](auto n0) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
                                                                    // k=0,kpack*1, ...
                    // read B
                    b_thread_copy_.Run(
                        b_block_desc_k0_n0_n1_n2_k1,
                        make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                        b_block_buf,
                        b_thread_desc_,
                        make_tuple(I0, n0, I0, I0, I0, I0),
                        b_thread_buf);
                    // read A
                    a_thread_copy_.Run(
                        a_block_desc_k0_m0_m1_m2_k1,
                        make_tuple(Number<k * WmmaK / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                        a_block_buf,
                        a_thread_desc_,
                        make_tuple(I0, m0, I0, I0, I0, I0),
                        a_thread_buf);

                    vector_type<FloatA, WmmaK> a_thread_vec;
                    vector_type<FloatB, WmmaK> b_thread_vec;

                    static_for<0, WmmaK, 1>{}([&](auto i) {
                        b_thread_vec.template AsType<FloatB>()(i) =
                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
                                i / B_K1 / B_KRow, n0, 0, (i / B_K1) % B_KRow, 0, i % B_K1))>{}];
                        a_thread_vec.template AsType<FloatA>()(i) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
                                i / A_K1 / A_KRow, m0, 0, (i / A_K1) % A_KRow, 0, i % A_K1))>{}];
                    });

                    using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
                    using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                    wmma_gemm.template Run(
                        a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
                        b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });
        });
    }
...
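The fill loop above decomposes each linear fragment index i into the 6-D thread-descriptor coordinate (K0, m0/n0, 0, KRow, 0, K1) via CalculateOffset. The following standalone sketch replays that decomposition; the values WmmaK = 16, A_K1 = 4, and A_KRow = 2 are illustrative assumptions, not values taken from this commit.

    // Hypothetical replay of the index split fed to CalculateOffset above.
    #include <cstdio>

    int main()
    {
        const int WmmaK = 16, A_K1 = 4, A_KRow = 2; // assumed sizes
        for(int i = 0; i < WmmaK; ++i)
        {
            const int k0   = i / A_K1 / A_KRow;   // slowest K component
            const int krow = (i / A_K1) % A_KRow; // row inside the wmma lane group
            const int k1   = i % A_K1;            // fastest: position in the K1 vector
            std::printf("i=%2d -> (K0=%d, KRow=%d, K1=%d)\n", i, k0, krow, k1);
        }
        return 0;
    }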
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
@@ -136,6 +136,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
+   static constexpr auto I6 = Number<6>{};

    static constexpr auto WmmaK = 16;
...
@@ -175,13 +176,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        }
        else
        {
-           return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
+           return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
                Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
                Number<WmmaK>{},
                Number<MRepeat>{},
                Number<MWaves>{},
                Number<MPerWmma>{},
                Number<K1>{});
        }
    }
...
@@ -197,14 +200,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        }
        else
        {
-           return Transform::MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
+           return Transform::MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
                Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
                Number<WmmaK>{},
                Number<LRepeat>{},
                Number<LWaves>{},
                Number<LPerWmma>{},
                Number<K1>{});
        }
    }
...
@@ -220,14 +224,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        }
        else
        {
-           return Transform::MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
+           return Transform::MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
                Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec),
                Number<WmmaK>{},
                Number<NRepeat>{},
                Number<NWaves>{},
                Number<NPerWmma>{},
                Number<L1>{});
        }
    }
...
@@ -521,7 +526,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        else
        {
            return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
-                  arg.a_grid_desc.GetLength(I5);
+                  arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
        }
    }();
...
@@ -826,7 +831,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            << "CSpec" << getTensorSpecializationString(CSpec) << ", "
            << getMaskingSpecializationString(MaskingSpec) << ">"
-           << " NumPrefetch: "
+           << " AEnableLds: " << AEnableLds << ", "
+           << "B0EnableLds: " << B0EnableLds << ", "
+           << "B1EnableLds: " << B1EnableLds << ", "
+           << "NumPrefetch: " << NumPrefetch << ", "
+           << "LoopScheduler: " << LoopSchedToString[LoopSched] << ", "
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
@@ -468,26 +468,25 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
                remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                has_main_k_block_loop>;

-           // Last Option is W/O
            return launch_and_time_kernel(stream_config,
                                          kernel,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          arg.p_a_grid_,
                                          arg.p_b_grid_,
                                          arg.p_ds_grid_,
                                          arg.p_e_grid_,
                                          arg.a_grid_desc,
                                          arg.b_grid_desc,
                                          arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                          arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                                          arg.a_element_op_,
                                          arg.b_element_op_,
                                          arg.cde_element_op_,
                                          arg.block_2_ctile_map_);
        };

        if(GridwiseOp::CalculateHasMainKBlockLoop(K))
        {
            return launch_kernel(integral_constant<bool, true>{});
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
@@ -398,21 +398,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                has_main_k_block_loop>;

            return launch_and_time_kernel(stream_config,
                                          kernel,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          arg.p_a_grid_,
                                          arg.p_b_grid_,
                                          arg.p_c_grid_,
                                          arg.a_grid_desc_,
                                          arg.b_grid_desc_k0_n_k1_,
                                          arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                                          arg.a_element_op_,
                                          arg.b_element_op_,
                                          arg.c_element_op_,
                                          arg.block_2_ctile_map_);
        };

        if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
@@ -243,10 +243,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        else
        {
            constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-           // KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
+           constexpr auto K0PerWmma     = WmmaK / 2 / AK1;
+           // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
            return make_naive_tensor_descriptor(
-               make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, AK1),
-               make_tuple(Number<MRepeat>{} * AK1, AK1, AK1, AK1, AK1, I1));
+               make_tuple(
+                   Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, Number<K0PerWmma>{}, I1, I1, AK1),
+               make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * AK1,
+                          Number<K0PerWmma>{} * AK1,
+                          Number<K0PerWmma>{} * AK1,
+                          AK1,
+                          AK1,
+                          AK1,
+                          I1));
        }
    }();
...
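A quick sanity check on the 7-D per-thread A descriptor introduced above is that it stays dense: the offset of its last element plus one must equal its element count. A minimal sketch, assuming KPerBlock = 64, WmmaK = 16, AK1 = 4, and MRepeat = 2 (all four values are illustrative, not taken from the diff):

    #include <cassert>
    #include <cstdio>

    int main()
    {
        const int KPerBlock = 64, WmmaK = 16, AK1 = 4, MRepeat = 2;
        const int KWmmaPerblock = KPerBlock / WmmaK; // 4
        const int K0PerWmma     = WmmaK / 2 / AK1;   // 2; the literal 2 matches the
                                                     // AKRow = 2 used elsewhere in this commit
        const int len[7]    = {KWmmaPerblock, MRepeat, 1, K0PerWmma, 1, 1, AK1};
        const int stride[7] = {MRepeat * K0PerWmma * AK1,
                               K0PerWmma * AK1,
                               K0PerWmma * AK1,
                               AK1,
                               AK1,
                               AK1,
                               1};

        int last = 0, count = 1;
        for(int d = 0; d < 7; ++d)
        {
            last += (len[d] - 1) * stride[d];
            count *= len[d];
        }
        assert(last + 1 == count); // packed despite the unit MWave/KRow dims
        std::printf("per-thread A elements: %d\n", count); // 64
        return 0;
    }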
@@ -277,10 +290,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        else
        {
            constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-           // KWmma->NRepeat->NWave->NRow->NPerWmma->BK1 Per Thread
+           constexpr auto K0PerWmma     = WmmaK / 2 / BK1;
+           // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
            return make_naive_tensor_descriptor(
-               make_tuple(Number<KWmmaPerblock>{}, Number<LRepeat>{}, I1, I1, I1, BK1),
-               make_tuple(Number<LRepeat>{} * BK1, BK1, BK1, BK1, BK1, I1));
+               make_tuple(
+                   Number<KWmmaPerblock>{}, Number<LRepeat>{}, I1, Number<K0PerWmma>{}, I1, I1, BK1),
+               make_tuple(Number<LRepeat>{} * Number<K0PerWmma>{} * BK1,
+                          Number<K0PerWmma>{} * BK1,
+                          Number<K0PerWmma>{} * BK1,
+                          BK1,
+                          BK1,
+                          BK1,
+                          I1));
        }
    }();
...
@@ -310,10 +336,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        else
        {
            constexpr auto LWmmaPerblock = LPerBlock / WmmaL;
-           // LWmma->NRepeat->NWave->NRow->LPerWmma->BL1 Per Thread
+           constexpr auto L0PerWmma     = WmmaL / 2 / BL1;
+           // LWmma->NRepeat->MWave->L0PerWmma->LRow->MPerWmma->L1 Per Thread
            return make_naive_tensor_descriptor(
-               make_tuple(Number<LWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, BL1),
-               make_tuple(Number<NRepeat>{} * BL1, BL1, BL1, BL1, BL1, I1));
+               make_tuple(
+                   Number<LWmmaPerblock>{}, Number<NRepeat>{}, I1, Number<L0PerWmma>{}, I1, I1, BL1),
+               make_tuple(Number<NRepeat>{} * Number<L0PerWmma>{} * BL1,
+                          Number<L0PerWmma>{} * BL1,
+                          Number<L0PerWmma>{} * BL1,
+                          BL1,
+                          BL1,
+                          BL1,
+                          I1));
        }
    }();
...
@@ -333,7 +372,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        {
            constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;

-           return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+           return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
        }
    }();
...
@@ -353,7 +392,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        {
            constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;

-           return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+           return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
        }
    }();
...
@@ -371,7 +410,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        {
            constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;

-           return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0);
+           return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0, 0);
        }
    }();
...
@@ -387,44 +426,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        if constexpr(AEnableLds)
        {
            // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
            constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
            constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+           constexpr auto A_KRow = I1;

            return transform_tensor_descriptor(
                ABlockDesc_{},
-               make_tuple(make_pass_through_transform(Number<A_K0>{}),
+               make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
                           make_unmerge_transform(make_tuple(
                               Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                           make_pass_through_transform(Number<A_K1>{})),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-               make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+               make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
        }
        else
        {
-           // KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
-           constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
-           constexpr auto A_K1  = ABlockDesc_{}.GetLength(I5);
-
-           // Workaround, Freeze transform
-           return transform_tensor_descriptor(
-               ABlockDesc_{},
-               make_tuple(make_freeze_transform(I0),
-                          make_pass_through_transform(Number<KWmma>{}),
-                          make_pass_through_transform(Number<MRepeat>{}),
-                          make_pass_through_transform(I1),
-                          make_pass_through_transform(I1),
-                          make_pass_through_transform(Number<A_K1>{})),
-               make_tuple(Sequence<3>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<5>{}),
-               make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+           // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
+           constexpr auto KWmma     = ABlockDesc_{}.GetLength(I0);
+           constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
+           constexpr auto A_KRow    = ABlockDesc_{}.GetLength(I4);
+           constexpr auto A_K1      = ABlockDesc_{}.GetLength(I6);
+
+           return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
+                                                                 Number<MRepeat>{},
+                                                                 I1,
+                                                                 Number<A_KRow>{},
+                                                                 I1,
+                                                                 Number<A_K1>{}));
        }
    }();
...
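In the AEnableLds branch above, swapping make_pass_through_transform for make_unmerge_transform splits lower dimension 0 into (A_K0, A_KRow) at upper dimensions {0, 3}, while M still unmerges into (MRepeat, MWaves, MPerWmma) at {1, 2, 4} and A_K1 passes through to dimension 5. A sketch of the resulting coordinate map, assuming illustrative sizes and the usual unmerge convention that the last component varies fastest:

    #include <cstdio>

    int main()
    {
        // assumed sizes: A_KRow = 1 (it is I1 above), MWaves = 2, MPerWmma = 16
        const int A_KRow = 1, MWaves = 2, MPerWmma = 16;
        const int k0 = 5, m = 37, k1 = 3; // an arbitrary lower coordinate (AK0, M, AK1)

        const int up[6] = {k0 / A_KRow,             // dim 0: AK0
                           m / (MWaves * MPerWmma), // dim 1: MRepeat
                           (m / MPerWmma) % MWaves, // dim 2: MWaves
                           k0 % A_KRow,             // dim 3: AKRow (always 0 here)
                           m % MPerWmma,            // dim 4: MPerWmma
                           k1};                     // dim 5: AK1
        std::printf("(%d, %d, %d) -> (%d, %d, %d, %d, %d, %d)\n",
                    k0, m, k1, up[0], up[1], up[2], up[3], up[4], up[5]);
        return 0;
    }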
@@ -439,44 +466,33 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        if constexpr(B0EnableLds)
        {
            // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
            constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
            constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
+           constexpr auto B_KRow = I1;

            return transform_tensor_descriptor(
                B0BlockDesc_{},
-               make_tuple(make_pass_through_transform(Number<B_K0>{}),
+               make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
                           make_unmerge_transform(make_tuple(
                               Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
                           make_pass_through_transform(Number<B_K1>{})),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-               make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+               make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
        }
        else
        {
-           // KWmma_LRepeat_LWave_KRow_LPerWmma_K1 -> K0_LRepeat_Lwaves_LPerWmma_K1
-           constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0);
-           constexpr auto B_K1  = B0BlockDesc_{}.GetLength(I5);
-
-           // Workaround, Freeze transform
-           return transform_tensor_descriptor(
-               B0BlockDesc_{},
-               make_tuple(make_freeze_transform(I0),
-                          make_pass_through_transform(Number<KWmma>{}),
-                          make_pass_through_transform(Number<LRepeat>{}),
-                          make_pass_through_transform(I1),
-                          make_pass_through_transform(I1),
-                          make_pass_through_transform(Number<B_K1>{})),
-               make_tuple(Sequence<3>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<5>{}),
-               make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+           // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
+           constexpr auto KWmma     = B0BlockDesc_{}.GetLength(I0);
+           constexpr auto K0PerWmma = B0BlockDesc_{}.GetLength(I3);
+           constexpr auto B_KRow    = B0BlockDesc_{}.GetLength(I4);
+           constexpr auto B_K1      = B0BlockDesc_{}.GetLength(I6);
+
+           return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
+                                                                 Number<LRepeat>{},
+                                                                 I1,
+                                                                 Number<B_KRow>{},
+                                                                 I1,
+                                                                 Number<B_K1>{}));
        }
    }();
...
@@ -489,14 +505,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        {
            constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
            constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
+           constexpr auto A_LRow  = I1;

            return transform_tensor_descriptor(
                A1BlockDesc_AL0_M_AL1{},
-               make_tuple(make_pass_through_transform(Number<A_L0>{}),
+               make_tuple(make_unmerge_transform(make_tuple(Number<A_L0>{}, A_LRow)),
                           make_unmerge_transform(make_tuple(Number<MRepeat>{}, I1, I1)),
                           make_pass_through_transform(Number<A_L1>{})),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-               make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+               make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
        }

    template <typename B1BlockDesc_>
...
@@ -507,44 +523,31 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        if constexpr(B1EnableLds)
        {
            // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
            constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
            constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
+           constexpr auto B_LRow = I1;

            return transform_tensor_descriptor(
                B1BlockDesc_{},
-               make_tuple(make_pass_through_transform(Number<B_L0>{}),
+               make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
                           make_unmerge_transform(make_tuple(
                               Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                           make_pass_through_transform(Number<B_L1>{})),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-               make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+               make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
        }
        else
        {
-           // LWmma_NRepeat_NWave_LRow_NPerWmma_L1 -> L0_NRepeat_Nwaves_NPerWmma_L1
-           constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0);
-           constexpr auto B_L1  = B1BlockDesc_{}.GetLength(I5);
-
-           // Workaround, Freeze transform
-           return transform_tensor_descriptor(
-               B1BlockDesc_{},
-               make_tuple(make_freeze_transform(I0),
-                          make_pass_through_transform(Number<LWmma>{}),
-                          make_pass_through_transform(Number<NRepeat>{}),
-                          make_pass_through_transform(I1),
-                          make_pass_through_transform(I1),
-                          make_pass_through_transform(Number<B_L1>{})),
-               make_tuple(Sequence<3>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<5>{}),
-               make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+           constexpr auto LWmma     = B1BlockDesc_{}.GetLength(I0);
+           constexpr auto L0PerWmma = B1BlockDesc_{}.GetLength(I3);
+           constexpr auto B_LRow    = B1BlockDesc_{}.GetLength(I4);
+           constexpr auto B_L1      = B1BlockDesc_{}.GetLength(I6);
+
+           return make_naive_tensor_descriptor_packed(make_tuple(Number<LWmma * L0PerWmma>{},
+                                                                 Number<NRepeat>{},
+                                                                 I1,
+                                                                 Number<B_LRow>{},
+                                                                 I1,
+                                                                 Number<B_L1>{}));
        }
    }();
...
@@ -610,9 +613,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        else
        {
-           return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
-                                 a_grid_desc.GetLength(I4),
-                             a_grid_desc.GetLength(I5),
-                             a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
-                                 a_grid_desc.GetLength(I5));
+           return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
+                                 a_grid_desc.GetLength(I5),
+                             a_grid_desc.GetLength(I6),
+                             a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                                 a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
        }
    };
...
@@ -625,9 +628,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        else
        {
-           return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) *
-                                 b0_grid_desc.GetLength(I4),
-                             b0_grid_desc.GetLength(I5),
-                             b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) *
-                                 b0_grid_desc.GetLength(I5));
+           return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) *
+                                 b0_grid_desc.GetLength(I5),
+                             b0_grid_desc.GetLength(I6),
+                             b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) *
+                                 b0_grid_desc.GetLength(I4) * b0_grid_desc.GetLength(I6));
        }
    };
...
@@ -640,9 +643,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
        else
        {
-           return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) *
-                                 b1_grid_desc.GetLength(I4),
-                             b1_grid_desc.GetLength(I5),
-                             b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) *
-                                 b1_grid_desc.GetLength(I5));
+           return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) *
+                                 b1_grid_desc.GetLength(I5),
+                             b1_grid_desc.GetLength(I6),
+                             b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) *
+                                 b1_grid_desc.GetLength(I4) * b1_grid_desc.GetLength(I6));
        }
    };
...
@@ -884,6 +887,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
            // Thread-wise copy
            // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
            constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+           constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
            auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
                a_block_desc.GetElementSpaceSize());
...
@@ -896,11 +900,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
                Sequence<Number<KWmmaPerBlock>{},
                         Number<MRepeat>{},
                         I1,
                         Number<K0PerWmma>{},
                         I1,
                         I1,
                         Number<K1Value>{}>,
-               Sequence<0, 1, 2, 3, 4, 5>,
-               5,
+               Sequence<0, 1, 2, 3, 4, 5, 6>,
+               6,
                ABlockTransferSrcScalarPerVector,
                AThreadTransferSrcResetCoordinateAfterRun,
                true>(
...
make_multi_index
(
0
,
m_block_data_idx_on_grid
/
(
MWaves
*
MPerWmma
),
get_thread_local_1d_id
()
/
32
,
0
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
get_thread_local_1d_id
()
%
16
,
0
));
...
...
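The literals 32 and 16 in the source coordinate above encode a fixed decomposition of the flat thread id into (wave, K-row, lane): 32 threads per wave, split into two 16-lane rows. A small sketch of that mapping, with the 64-thread block size assumed purely for illustration:

    #include <cstdio>

    int main()
    {
        for(int tid = 0; tid < 64; tid += 16) // assumed 64-thread block
        {
            const int wave = tid / 32;        // MWave index
            const int krow = (tid % 32) / 16; // which 16-lane row of the wave
            const int lane = tid % 16;        // lane within the row -> MPerWmma
            std::printf("tid=%2d -> wave=%d, krow=%d, lane=%d\n", tid, wave, krow, lane);
        }
        return 0;
    }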
@@ -960,6 +966,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
            // Thread-wise copy
            // KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1
            constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+           constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
            auto b0_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B0DataType>(
                b0_block_desc.GetElementSpaceSize());
...
@@ -972,11 +979,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
                Sequence<Number<KWmmaPerBlock>{},
                         Number<LRepeat>{},
                         I1,
                         Number<K0PerWmma>{},
                         I1,
                         I1,
                         Number<K1Value>{}>,
-               Sequence<0, 1, 2, 3, 4, 5>,
-               5,
+               Sequence<0, 1, 2, 3, 4, 5, 6>,
+               6,
                B0BlockTransferSrcScalarPerVector,
                B0ThreadTransferSrcResetCoordinateAfterRun,
                true>(
...
@@ -984,6 +992,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
                make_multi_index(0,
                                 0 / (LWaves * LPerWmma),
                                 get_thread_local_1d_id() / 32,
+                                0,
                                 (get_thread_local_1d_id() % 32) / 16,
                                 get_thread_local_1d_id() % 16,
                                 0));
...
@@ -1054,7 +1063,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
            return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0);
        }
        else
        {
-           return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0);
+           return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0, 0);
        }
    }();
...
@@ -1063,7 +1072,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
            return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0);
        }
        else
        {
-           return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0);
+           return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0, 0);
        }
    }();
...
@@ -1072,7 +1081,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
            return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
        }
        else
        {
-           return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
+           return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                  a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
        }
    }();
...
@@ -1208,6 +1218,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
            // Thread-wise copy
            // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
            constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
+           constexpr auto L0PerWmma     = WmmaL / 2 / L1Value;
            auto b1_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B1DataType>(
                b1_block_desc.GetElementSpaceSize());
...
@@ -1220,11 +1231,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
                Sequence<Number<LWmmaPerBlock>{},
                         Number<NRepeat>{},
                         I1,
                         Number<L0PerWmma>{},
                         I1,
                         I1,
                         Number<L1Value>{}>,
-               Sequence<0, 1, 2, 3, 4, 5>,
-               5,
+               Sequence<0, 1, 2, 3, 4, 5, 6>,
+               6,
                B1BlockTransferSrcScalarPerVector,
                B1ThreadTransferSrcResetCoordinateAfterRun,
                true>(
...
@@ -1232,6 +1244,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
                make_multi_index(0,
                                 n_block_data_idx_on_grid / (NWaves * NPerWmma),
                                 get_thread_local_1d_id() / 32,
+                                0,
                                 (get_thread_local_1d_id() % 32) / 16,
                                 get_thread_local_1d_id() % 16,
                                 0));
...
@@ -1262,7 +1275,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
            KPack,
            false,
            B1EnableLds,
-           true>{make_tuple(0, 0, 0, 0, 0)};
+           true>{make_tuple(0, 0, 0, 0, 0, 0)};

        auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
...
@@ -1271,7 +1284,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
            return b0_grid_desc.GetLength(I1);
        }
        else
        {
-           return b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * b0_grid_desc.GetLength(I4);
+           return b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * b0_grid_desc.GetLength(I5);
        }
    }();
...
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
@@ -186,7 +186,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
              typename MPerWmma,
              typename AK1>
    __host__ __device__ static constexpr auto
-   MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
+   MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
        const AGridDesc_M_K& a_grid_desc_m_k,
        const WmmaK&,
        const MRepeat&,
...
@@ -194,17 +194,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
        const MPerWmma&,
        const AK1&)
    {
-       const auto M0        = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
-       const auto K         = a_grid_desc_m_k.GetLength(I1);
-       const auto AKWmma    = K / WmmaK{};
-       constexpr auto AKRow = WmmaK{} / AK1{};
+       const auto M0             = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
+       const auto K              = a_grid_desc_m_k.GetLength(I1);
+       const auto AKWmma         = K / WmmaK{};
+       constexpr auto AKRow      = 2;
+       constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
-           make_tuple(make_unmerge_transform(make_tuple(AKWmma, AKRow, AK1{})),
+           make_tuple(make_unmerge_transform(
+                          make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
                       make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
-           make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+           make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }
    //
...
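The net effect of this hunk: the K-row count is pinned to 2 instead of being derived as WmmaK / AK1, and the residual factor becomes AK0PerWmma, so K now unmerges into four components (AKWmma, AK0PerWmma, AKRow, AK1) instead of three. A minimal arithmetic check, assuming AK1 = 4 (AK1 is a tuning parameter; the value is an example only):

    #include <cassert>

    int main()
    {
        const int WmmaK = 16, AK1 = 4;

        const int old_AKRow = WmmaK / AK1; // old derivation: 4

        const int AKRow      = 2;                   // new: fixed
        const int AK0PerWmma = WmmaK / AKRow / AK1; // 16 / 2 / 4 = 2

        assert(old_AKRow == 4);
        assert(AK0PerWmma * AKRow * AK1 == WmmaK); // both factorizations cover WmmaK
        return 0;
    }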
@@ -254,7 +256,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
              typename LPerWmma,
              typename BK1>
    __host__ __device__ static constexpr auto
-   MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
+   MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
        const BGridDesc_L_K& b_grid_desc_l_k,
        const WmmaK&,
        const LRepeat&,
...
@@ -262,17 +264,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
        const LPerWmma&,
        const BK1&)
    {
-       const auto L0        = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
-       const auto K         = b_grid_desc_l_k.GetLength(I1);
-       const auto BKWmma    = K / WmmaK{};
-       constexpr auto BKRow = WmmaK{} / BK1{};
+       const auto L0             = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
+       const auto K              = b_grid_desc_l_k.GetLength(I1);
+       const auto BKWmma         = K / WmmaK{};
+       constexpr auto BKRow      = 2;
+       constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};

        return transform_tensor_descriptor(
            b_grid_desc_l_k,
-           make_tuple(make_unmerge_transform(make_tuple(BKWmma, BKRow, BK1{})),
+           make_tuple(make_unmerge_transform(
+                          make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
                       make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
-           make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+           make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }
    //
...
@@ -323,7 +327,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
              typename NPerWmma,
              typename BL1>
    __host__ __device__ static constexpr auto
-   MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
+   MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
        const BGridDesc_N_L& b_grid_desc_n_l,
        const WmmaL&,
        const NRepeat&,
...
@@ -331,17 +335,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
        const NPerWmma&,
        const BL1&)
    {
-       const auto N0        = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
-       const auto L         = b_grid_desc_n_l.GetLength(I1);
-       const auto BLWmma    = L / WmmaL{};
-       constexpr auto BLRow = WmmaL{} / BL1{};
+       const auto N0             = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
+       const auto L              = b_grid_desc_n_l.GetLength(I1);
+       const auto BLWmma         = L / WmmaL{};
+       constexpr auto BLRow      = 2;
+       constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};

        return transform_tensor_descriptor(
            b_grid_desc_n_l,
-           make_tuple(make_unmerge_transform(make_tuple(BLWmma, BLRow, BL1{})),
+           make_tuple(make_unmerge_transform(
+                          make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
                       make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
-           make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+           make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
    }
    //
...