Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bee4e344
Commit
bee4e344
authored
May 19, 2023
by
aska-0096
Browse files
(5/5) attention pass, todo: debug lds perf bug
parent
fd4ff3a7
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
276 additions
and
248 deletions
+276
-248
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
...ple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+48
-49
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+36
-25
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
.../gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
+18
-19
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+15
-15
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
...grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
+130
-117
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
...tion/operator_transform/transform_contraction_to_gemm.hpp
+27
-21
No files found.
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
View file @
bee4e344
...
@@ -53,8 +53,8 @@ using DeviceConvFwdInstance =
...
@@ -53,8 +53,8 @@ using DeviceConvFwdInstance =
GemmSpec
,
// GemmSpecialization
GemmSpec
,
// GemmSpecialization
1
,
// Prefetch stage
1
,
// Prefetch stage
128
,
// BlockSize
128
,
// BlockSize
64
,
// MPerBlock
64
,
// MPerBlock
64
,
// NPerBlock
64
,
// NPerBlock
64
,
// KPerBlock
64
,
// KPerBlock
4
,
// K1
4
,
// K1
16
,
// MPerWMMA
16
,
// MPerWMMA
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
bee4e344
...
@@ -305,7 +305,7 @@ struct BlockwiseGemmWMMA
...
@@ -305,7 +305,7 @@ struct BlockwiseGemmWMMA
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}(
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}(
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
// read A
a_thread_copy_
.
Run
(
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
...
@@ -365,58 +365,57 @@ struct BlockwiseGemmWMMA
...
@@ -365,58 +365,57 @@ struct BlockwiseGemmWMMA
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}(
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}([
&
](
auto
k
)
{
// k=0,1,2 instead of
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
// k=0,kpack*1, ... read B
// read B
b_thread_copy_
.
Run
(
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
WmmaK
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_block_buf
,
b_thread_desc_
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
b_thread_buf
);
// read A
// read A
a_thread_copy_
.
Run
(
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
WmmaK
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_block_buf
,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
a_thread_buf
);
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
vector_type
<
FloatA
,
WmmaK
>
a_thread_vec
;
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
vector_type
<
FloatB
,
WmmaK
>
b_thread_vec
;
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
/
B_KRow
,
static_for
<
0
,
WmmaK
,
1
>
{}([
&
](
auto
i
)
{
n0
,
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
0
,
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
(
i
/
B_K1
)
%
B_KRow
,
make_tuple
(
i
/
B_K1
/
B_KRow
,
0
,
n0
,
i
%
B_K1
))
>
{}];
0
,
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
(
i
/
B_K1
)
%
B_KRow
,
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
0
,
make_tuple
(
i
/
A_K1
/
A_KRow
,
i
%
B_K1
))
>
{}];
m0
,
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
0
,
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
(
i
/
A_K1
)
%
A_KRow
,
make_tuple
(
i
/
A_K1
/
A_KRow
,
0
,
m0
,
i
%
A_K1
))
>
{}];
0
,
});
(
i
/
A_K1
)
%
A_KRow
,
0
,
i
%
A_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
>::
type
;
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>()(
Number
<
0
>{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>()(
Number
<
0
>
{}),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
});
});
});
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @
bee4e344
...
@@ -136,6 +136,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -136,6 +136,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
WmmaK
=
16
;
static
constexpr
auto
WmmaK
=
16
;
...
@@ -175,13 +176,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -175,13 +176,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
}
}
else
else
{
{
return
Transform
::
MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1
(
return
Transform
::
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
a_gs_ms_ks_strides_vec
),
MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1
(
Number
<
WmmaK
>
{},
Transform
::
MakeAGridDescriptor_M_K
(
a_gs_ms_ks_lengths_vec
,
Number
<
MRepeat
>
{},
a_gs_ms_ks_strides_vec
),
Number
<
MWaves
>
{},
Number
<
WmmaK
>
{},
Number
<
MPerWmma
>
{},
Number
<
MRepeat
>
{},
Number
<
K1
>
{});
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{},
Number
<
K1
>
{});
}
}
}
}
...
@@ -197,14 +200,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -197,14 +200,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
}
}
else
else
{
{
return
Transform
::
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1
(
return
Transform
::
Transform
::
MakeB0GridDescriptor_N_K
(
b0_gs_ls_ks_lengths_vec
,
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1
(
b0_gs_ls_ks_strides_vec
),
Transform
::
MakeB0GridDescriptor_N_K
(
b0_gs_ls_ks_lengths_vec
,
Number
<
WmmaK
>
{},
b0_gs_ls_ks_strides_vec
),
Number
<
LRepeat
>
{},
Number
<
WmmaK
>
{},
Number
<
LWaves
>
{},
Number
<
LRepeat
>
{},
Number
<
LPerWmma
>
{},
Number
<
LWaves
>
{},
Number
<
K1
>
{});
Number
<
LPerWmma
>
{},
Number
<
K1
>
{});
}
}
}
}
...
@@ -220,14 +224,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -220,14 +224,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
}
}
else
else
{
{
return
Transform
::
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1
(
return
Transform
::
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_ns_ls_lengths_vec
,
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1
(
b1_gs_ns_ls_strides_vec
),
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_ns_ls_lengths_vec
,
Number
<
WmmaK
>
{},
b1_gs_ns_ls_strides_vec
),
Number
<
NRepeat
>
{},
Number
<
WmmaK
>
{},
Number
<
NWaves
>
{},
Number
<
NRepeat
>
{},
Number
<
NPerWmma
>
{},
Number
<
NWaves
>
{},
Number
<
L1
>
{});
Number
<
NPerWmma
>
{},
Number
<
L1
>
{});
}
}
}
}
...
@@ -521,7 +526,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -521,7 +526,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
else
else
{
{
return
arg
.
a_grid_desc
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
.
GetLength
(
I3
)
*
return
arg
.
a_grid_desc
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
.
GetLength
(
I3
)
*
arg
.
a_grid_desc
.
GetLength
(
I
5
);
arg
.
a_grid_desc
.
GetLength
(
I
4
)
*
arg
.
a_grid_desc
.
GetLength
(
I6
);
}
}
}();
}();
...
@@ -826,7 +831,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -826,7 +831,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
<<
"CSpec"
<<
getTensorSpecializationString
(
CSpec
)
<<
", "
<<
"CSpec"
<<
getTensorSpecializationString
(
CSpec
)
<<
", "
<<
getMaskingSpecializationString
(
MaskingSpec
)
<<
getMaskingSpecializationString
(
MaskingSpec
)
<<
">"
<<
">"
<<
" NumPrefetch: "
<<
" AEnableLds: "
<<
AEnableLds
<<
", "
<<
"B0EnableLds: "
<<
B0EnableLds
<<
", "
<<
"B1EnableLds: "
<<
B1EnableLds
<<
", "
<<
"NumPrefetch: "
<<
NumPrefetch
<<
", "
<<
NumPrefetch
<<
", "
<<
"LoopScheduler: "
<<
"LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
View file @
bee4e344
...
@@ -468,26 +468,25 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -468,26 +468,25 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
has_main_k_block_loop
>
;
// Last Option is W/O
has_main_k_block_loop
>
;
// Last Option is W/O
return
return
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
p_e_grid_
,
arg
.
a_grid_desc
,
arg
.
a_grid_desc
,
arg
.
b_grid_desc
,
arg
.
b_grid_desc
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
cde_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
};
};
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
bee4e344
...
@@ -398,21 +398,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -398,21 +398,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
has_main_k_block_loop
>
;
has_main_k_block_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_
,
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
};
};
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
View file @
bee4e344
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
View file @
bee4e344
...
@@ -186,7 +186,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -186,7 +186,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
typename
MPerWmma
,
typename
MPerWmma
,
typename
AK1
>
typename
AK1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1
(
MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_
AK0PerWmma_
AKRow_MPerWmma_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
,
const
AGridDesc_M_K
&
a_grid_desc_m_k
,
const
WmmaK
&
,
const
WmmaK
&
,
const
MRepeat
&
,
const
MRepeat
&
,
...
@@ -194,17 +194,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -194,17 +194,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
const
MPerWmma
&
,
const
MPerWmma
&
,
const
AK1
&
)
const
AK1
&
)
{
{
const
auto
M0
=
a_grid_desc_m_k
.
GetLength
(
I0
)
/
MPerBlock
;
const
auto
M0
=
a_grid_desc_m_k
.
GetLength
(
I0
)
/
MPerBlock
;
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
const
auto
AKWmma
=
K
/
WmmaK
{};
const
auto
AKWmma
=
K
/
WmmaK
{};
constexpr
auto
AKRow
=
WmmaK
{}
/
AK1
{};
constexpr
auto
AKRow
=
2
;
constexpr
auto
AK0PerWmma
=
WmmaK
{}
/
AKRow
/
AK1
{};
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AKWmma
,
AKRow
,
AK1
{})),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AKWmma
,
Number
<
AK0PerWmma
>
{},
Number
<
AKRow
>
{},
AK1
{})),
make_unmerge_transform
(
make_tuple
(
M0
*
MRepeat
{},
MWaves
{},
MPerWmma
{}))),
make_unmerge_transform
(
make_tuple
(
M0
*
MRepeat
{},
MWaves
{},
MPerWmma
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
5
>
{},
Sequence
<
1
,
2
,
4
>
{}));
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
//
//
...
@@ -254,7 +256,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -254,7 +256,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
typename
LPerWmma
,
typename
LPerWmma
,
typename
BK1
>
typename
BK1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1
(
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_
BK0PerWmma_
BKRow_LPerWmma_BK1
(
const
BGridDesc_L_K
&
b_grid_desc_l_k
,
const
BGridDesc_L_K
&
b_grid_desc_l_k
,
const
WmmaK
&
,
const
WmmaK
&
,
const
LRepeat
&
,
const
LRepeat
&
,
...
@@ -262,17 +264,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -262,17 +264,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
const
LPerWmma
&
,
const
LPerWmma
&
,
const
BK1
&
)
const
BK1
&
)
{
{
const
auto
L0
=
b_grid_desc_l_k
.
GetLength
(
I0
)
/
NPerBlock
;
const
auto
L0
=
b_grid_desc_l_k
.
GetLength
(
I0
)
/
NPerBlock
;
const
auto
K
=
b_grid_desc_l_k
.
GetLength
(
I1
);
const
auto
K
=
b_grid_desc_l_k
.
GetLength
(
I1
);
const
auto
BKWmma
=
K
/
WmmaK
{};
const
auto
BKWmma
=
K
/
WmmaK
{};
constexpr
auto
BKRow
=
WmmaK
{}
/
BK1
{};
constexpr
auto
BKRow
=
2
;
constexpr
auto
BK0PerWmma
=
WmmaK
{}
/
BKRow
/
BK1
{};
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
b_grid_desc_l_k
,
b_grid_desc_l_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BKWmma
,
BKRow
,
BK1
{})),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BKWmma
,
Number
<
BK0PerWmma
>
{},
Number
<
BKRow
>
{},
BK1
{})),
make_unmerge_transform
(
make_tuple
(
L0
*
LRepeat
{},
LWaves
{},
LPerWmma
{}))),
make_unmerge_transform
(
make_tuple
(
L0
*
LRepeat
{},
LWaves
{},
LPerWmma
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
5
>
{},
Sequence
<
1
,
2
,
4
>
{}));
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
//
//
...
@@ -323,7 +327,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -323,7 +327,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
typename
NPerWmma
,
typename
NPerWmma
,
typename
BL1
>
typename
BL1
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1
(
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_
_BL0PerWmma_
BLRow_NPerWmma_BL1
(
const
BGridDesc_N_L
&
b_grid_desc_n_l
,
const
BGridDesc_N_L
&
b_grid_desc_n_l
,
const
WmmaL
&
,
const
WmmaL
&
,
const
NRepeat
&
,
const
NRepeat
&
,
...
@@ -331,17 +335,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
...
@@ -331,17 +335,19 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
const
NPerWmma
&
,
const
NPerWmma
&
,
const
BL1
&
)
const
BL1
&
)
{
{
const
auto
N0
=
b_grid_desc_n_l
.
GetLength
(
I0
)
/
OPerBlock
;
const
auto
N0
=
b_grid_desc_n_l
.
GetLength
(
I0
)
/
OPerBlock
;
const
auto
L
=
b_grid_desc_n_l
.
GetLength
(
I1
);
const
auto
L
=
b_grid_desc_n_l
.
GetLength
(
I1
);
const
auto
BLWmma
=
L
/
WmmaL
{};
const
auto
BLWmma
=
L
/
WmmaL
{};
constexpr
auto
BLRow
=
WmmaL
{}
/
BL1
{};
constexpr
auto
BLRow
=
2
;
constexpr
auto
BL0PerWmma
=
WmmaL
{}
/
BLRow
/
BL1
{};
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
b_grid_desc_n_l
,
b_grid_desc_n_l
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BLWmma
,
BLRow
,
BL1
{})),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BLWmma
,
Number
<
BL0PerWmma
>
{},
Number
<
BLRow
>
{},
BL1
{})),
make_unmerge_transform
(
make_tuple
(
N0
*
NRepeat
{},
NWaves
{},
NPerWmma
{}))),
make_unmerge_transform
(
make_tuple
(
N0
*
NRepeat
{},
NWaves
{},
NPerWmma
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
5
>
{},
Sequence
<
1
,
2
,
4
>
{}));
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
//
//
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment