gaoqiong / composable_kernel / Commits / 7e003d31

Commit 7e003d31, authored Feb 28, 2023 by aska-0096

    Porting new blockwise gemm to flash attention

parent 84b4ada5
Showing 6 changed files with 209 additions and 136 deletions.
  +7   -7    example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
  +6   -6    example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc
  +4   -2    include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
  +51  -35   include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
  +121 -74   include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
  +20  -12   include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp

@@ -100,12 +100,12 @@ using DeviceGemmInstance =
         32,          // KPerBlock
         8,           // K1
         // Gemm 1
         64,          // NPerBlock
-        32,          // LPerBlock
+        32,          // LTilePerBlock
         8,           // L1
         16,          // MPerWMMA
         16,          // LPerWMMA
         16,          // NPerWMMA
         // Per repeat = wave_m = wave_num, wave_n = 1
         1,           // MRepeat
         8,           // LRepeat
@@ -124,7 +124,7 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<4, 8, 8>, // B1BlockTransfer LN -> L0 N L1
+        S<4, 8, 8>, // B1BlockTransfer NL -> L0 N L1
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
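The instance above is configured entirely through compile-time tile parameters. For reference, the device op modified later in this commit derives its wave counts from these numbers (e.g. LWaves = LPerBlock / (LRepeat * LPerWmma)). A minimal sketch of that arithmetic, using the example's LRepeat = 8 and LPerWMMA = 16 together with a hypothetical LPerBlock of 128:

    // Illustrative only, not part of the commit. LPerBlock = 128 is a
    // hypothetical value chosen so the division comes out evenly.
    constexpr int LRepeat   = 8;   // from the example instance
    constexpr int LPerWmma  = 16;  // from the example instance
    constexpr int LPerBlock = 128; // hypothetical
    constexpr int LWaves    = LPerBlock / (LRepeat * LPerWmma); // 128 / 128 = 1
    static_assert(LWaves == 1, "one wave covers the whole L tile with these numbers");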
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc

@@ -122,20 +122,20 @@ int run(int argc, char* argv[])
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
-    case 5: // Rand: b1 ; unit: a b0 fail
+    case 5: // Rand: b1 b0 ; unit: a
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
-        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
-    case 6: // Rand: b0 ; unit: a b1 pass
+    case 6: // Rand: a b0 ; unit: b1 pass
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
-    case 7: // Rand: a ; unit: b0 b1 pass
+    case 7: // Rand: a b1 ; unit: b0 pass
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
-        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
        break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
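In the cases above, GeneratorTensor_1 is used for the tensors labelled "unit" and GeneratorTensor_2{-2, 2} for the ones labelled "Rand", which suggests a constant fill versus a bounded random fill. The CK generator types themselves are not shown in this diff, so the following is only a rough standalone analogue of that initialization pattern; UnitFill and RandomIntFill are hypothetical stand-ins, not the CK generators.

    #include <cstdlib>
    #include <vector>

    // Hypothetical stand-in for a "fill everything with 1" generator.
    struct UnitFill
    {
        float operator()() const { return 1.0f; }
    };

    // Hypothetical stand-in for a "random integer in [min, max)" generator.
    struct RandomIntFill
    {
        int min_;
        int max_;
        float operator()() const { return static_cast<float>(min_ + std::rand() % (max_ - min_)); }
    };

    template <typename T, typename Generator>
    void generate_tensor_value(std::vector<T>& tensor, Generator gen)
    {
        for(auto& v : tensor)
            v = static_cast<T>(gen());
    }

    // Usage mirroring "case 7: Rand: a b1 ; unit: b0":
    //   generate_tensor_value(a,  RandomIntFill{-2, 2});
    //   generate_tensor_value(b0, UnitFill{});
    //   generate_tensor_value(b1, RandomIntFill{-2, 2});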
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -7,6 +7,7 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
 #include "ck/tensor_description/tensor_adaptor.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

 #define CK_MNK_LOOP
@@ -340,6 +341,7 @@ struct BlockwiseGemmWMMA
                          b_thread_desc_,
                          make_tuple(I0, n0, I0, I0, I0),
                          b_thread_buf);

                vector_type<FloatA, WmmaK> a_thread_vec;
                vector_type<FloatB, WmmaK> b_thread_vec;
@@ -413,7 +415,7 @@ struct BlockwiseGemmWMMA
                                                     A_K1,
                                                     0x76543210,
                                                     0xfedcba98,
-                                                     true>;
+                                                     TransposeC ? false : true>;
    };

    template <bool EnableLds>
@@ -448,7 +450,7 @@ struct BlockwiseGemmWMMA
                                                     B_K1,
                                                     0x76543210,
                                                     0xfedcba98,
-                                                     false>;
+                                                     TransposeC ? true : false>;
    };

    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
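Two details of the block above are worth spelling out. First, the new trailing template arguments are plain negations: "TransposeC ? false : true" is !TransposeC, and "TransposeC ? true : false" is TransposeC. Second, AThreadCopySelector<AEnableLds>::type appears to pick one of two thread-copy implementations at compile time. A toy reduction of that selector idiom is sketched below; the type names are made up for illustration and are not the CK types.

    #include <type_traits>

    // Hypothetical stand-ins for the two copy strategies.
    struct CopyThroughLds {};
    struct CopyDirect {};

    template <bool EnableLds>
    struct ThreadCopySelector
    {
        // std::conditional_t is one simple way to express the choice; the CK
        // selector builds each alternative from many more template parameters.
        using type = std::conditional_t<EnableLds, CopyThroughLds, CopyDirect>;
    };

    static_assert(std::is_same_v<ThreadCopySelector<true>::type, CopyThroughLds>);
    static_assert(std::is_same_v<ThreadCopySelector<false>::type, CopyDirect>);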
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp

@@ -56,11 +56,11 @@ template <index_t NumDimG,
          ck::index_t KPerBlock,
          ck::index_t K1,
          ck::index_t NPerBlock,
-          ck::index_t LPerBlock,
+          ck::index_t LTilePerBlock,
          ck::index_t L1,
-          ck::index_t MPerWMMA,
+          ck::index_t MPerWmma,
-          ck::index_t LPerWMMA,
+          ck::index_t LPerWmma,
-          ck::index_t NPerWMMA,
+          ck::index_t NPerWmma,
          ck::index_t MRepeat,
          ck::index_t LRepeat,
          ck::index_t NRepeat,
@@ -134,15 +134,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};

-    static constexpr auto WmmaK  = 16;
    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+    static constexpr auto WmmaK  = 16;

    static constexpr auto AEnableLds = LWaves == 1 ? false : true;
-    // static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
+    static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
-    // static constexpr auto B1EnableLds = MWaves == 1 ? false : true;
+    static constexpr auto B1EnableLds = MWaves == 1 ? false : true;

    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
        Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
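The LDS switches above are plain ternaries over the wave counts: "LWaves == 1 ? false : true" is simply LWaves != 1. Reading the flag names, the intent appears to be that staging a tile through LDS is only enabled when more than one wave consumes it; that interpretation is an assumption, not something stated in the diff. A minimal sketch:

    // Sketch of the EnableLds selection above. The "skip LDS when a single
    // wave owns the tile" reading is an assumption based on the flag names.
    constexpr bool should_enable_lds(int waves) { return waves != 1; }

    static_assert(should_enable_lds(1) == false); // one wave: read operands directly
    static_assert(should_enable_lds(4) == true);  // several waves: stage through LDS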
@@ -165,14 +168,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        else
        {
            return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
                Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
-                WmmaK, Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{}, Number<K1>{})
+                Number<WmmaK>{}, Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{}, Number<K1>{});
        }
    }

-    static auto MakeB0GridDescriptor_BK0_L_BK1(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
+    static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
                                    const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
    {
        return Transform::MakeB0GridDescriptor_BK0_N_BK1(
            Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
@@ -188,7 +194,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    }

    using AGridDesc            = decltype(MakeAGridDescriptor({}, {}));
-    using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor_BK0_L_BK1({}, {}));
+    using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor({}, {}));
    using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
@@ -277,11 +283,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             KPerBlock,
             K1,
             NPerBlock,
-             LPerBlock,
+             LTilePerBlock,
             L1,
-             MPerWMMA,
+             MPerWmma,
-             LPerWMMA,
+             LPerWmma,
-             NPerWMMA,
+             NPerWmma,
             MRepeat,
             LRepeat,
             NRepeat,
@@ -357,10 +363,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
              p_b0_grid_{p_b0_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
-              a_grid_desc_ak0_m_ak1_{
-                  DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
+              a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              b0_grid_desc_bk0_l_bk1_{
-                  DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(
+                  DeviceOp::MakeB0GridDescriptor(
                      b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
              b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(
                  b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
              c_grid_desc_m_n_{
@@ -405,7 +410,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        ignore = acc1_biases_gs_ms_ns_lengths;
        ignore = acc1_biases_gs_ms_ns_strides;

-        if(GridwiseOp::CheckValidity(a_grid_desc_ak0_m_ak1_,
+        if(GridwiseOp::CheckValidity(a_grid_desc,
                                     b0_grid_desc_bk0_l_bk1_,
                                     b1_grid_desc_bl0_n_bl1_,
                                     c_grid_desc_m_n_,
@@ -424,7 +429,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        CDataType* p_c_grid_;

        // Tensor Descriptors
-        AGridDesc a_grid_desc_ak0_m_ak1_;
+        AGridDesc a_grid_desc;
        B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
        B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
        CGridDesc_M_N c_grid_desc_m_n_;
@@ -473,8 +478,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;

-            const auto K =
-                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+            const auto K = [&]() {
+                if constexpr(AEnableLds)
+                {
+                    return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
+                }
+                else
+                {
+                    return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
+                           arg.a_grid_desc.GetLength(I5);
+                }
+            }();

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
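The new K lambda reads different descriptor dimensions depending on AEnableLds. Judging by the descriptor names, the LDS path keeps an (AK0, M, AK1) layout, so K is GetLength(I0) * GetLength(I2) = AK0 * AK1, while the direct path uses the (AKWmma, MBlockRepeat, MWaves, AKRow, MPerWmma, AK1) layout, so K is GetLength(I0) * GetLength(I3) * GetLength(I5) = AKWmma * AKRow * AK1. With the relations from transform_contraction_to_gemm.hpp (AKWmma = K / WmmaK, AKRow = WmmaK / AK1), both products recover the same K. A compile-time sketch of that bookkeeping, with a hypothetical K of 64:

    // K = 64 is hypothetical; WmmaK = 16 is the constant defined in this device
    // op, and AK1 = 8 matches K1 in the fp16 example.
    constexpr int K     = 64;
    constexpr int WmmaK = 16;
    constexpr int AK1   = 8;

    constexpr int AK0    = K / AK1;     // LDS path, dimension I0
    constexpr int AKWmma = K / WmmaK;   // direct path, dimension I0
    constexpr int AKRow  = WmmaK / AK1; // direct path, dimension I3

    static_assert(AK0 * AK1 == K);            // GetLength(I0) * GetLength(I2)
    static_assert(AKWmma * AKRow * AK1 == K); // GetLength(I0) * GetLength(I3) * GetLength(I5)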
@@ -506,7 +520,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                    arg.p_b0_grid_,
                    arg.p_b1_grid_,
                    arg.p_c_grid_,
-                    arg.a_grid_desc_ak0_m_ak1_,
+                    arg.a_grid_desc,
                    arg.b0_grid_desc_bk0_l_bk1_,
                    arg.b1_grid_desc_bl0_n_bl1_,
                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
@@ -551,20 +565,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        {
            if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
            {
+                printf("DeviceOp: Acc0 Type err");
                return false;
            }

            if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
            {
+                printf("DeviceOp: Acc1 Type err");
                return false;
            }
        }
        else
        {
+            printf("DeviceOp: Arch err");
            return false;
        }

-        if(!GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+        if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
                                      arg.b0_grid_desc_bk0_l_bk1_,
                                      arg.b1_grid_desc_bl0_n_bl1_,
                                      arg.c_grid_desc_m_n_,
@@ -574,14 +591,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        }

        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
-        const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
-        const index_t c_n = arg.c_grid_desc_m_n_.GetLength(I1);
-        const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-        const index_t b1_n = arg.b1_grid_desc_bl0_n_bl1_.GetLength(I1);
-        if(!(c_g == arg.batch_count_ && c_m == a_m && c_n == b1_n))
+        if(!(c_g == arg.batch_count_))
        {
+            printf("DeviceOp: BatchCount err");
            return false;
        }
@@ -604,6 +618,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
+            printf("DeviceOp: Data Transfer Vector scalar err");
            return false;
        }
@@ -619,6 +634,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
        {
+            printf("DeviceOp: Data Vectorize transfer err");
            return false;
        }
@@ -765,7 +781,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            << K1 << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
-            << LPerBlock << ", "
+            << LTilePerBlock << ", "
            << L1 << ", "
            << getGemmSpecializationString(GemmSpec) << ", "
            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp

(This diff is collapsed in the page view and is not shown here.)
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp

@@ -179,24 +179,32 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
                       make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

-    template <typename AGridDesc_M_K, typename Number>
+    template <typename AGridDesc_M_K,
+              typename WmmaK,
+              typename MRepeat,
+              typename MWaves,
+              typename MPerWmma,
+              typename AK1>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
-        const AGridDesc_M_K& a_grid_desc_m_k,
-        const Number& WmmaK,
-        const Number& MRepeat,
-        const Number& MWaves,
-        const Number& MPerWmma,
-        const Number& AK1)
+        const AGridDesc_M_K& a_grid_desc_m_k,
+        const WmmaK&,
+        const MRepeat&,
+        const MWaves&,
+        const MPerWmma&,
+        const AK1&)
    {
-        const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlcok;
+        const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
        const auto K  = a_grid_desc_m_k.GetLength(I1);
-        const auto AKWmma = K / WmmaK;
+        const auto AKWmma = K / WmmaK{};
-        constexpr auto AKRow = WmmaK / K1;
+        constexpr auto AKRow = WmmaK{} / AK1{};

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
-            make_tuple(make_unmerge_transform(make_tuple(AKWmma, Number<AKRow>{}, AK1)),
+            make_tuple(make_unmerge_transform(make_tuple(AKWmma, AKRow, AK1{})),
-                       make_unmerge_transform(make_tuple(M0 * MRepeat, MWaves, MPerWmma))),
+                       make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
    }

    //
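The signature change above turns the Wmma/Repeat/Waves/AK1 arguments from values of a single Number type into individual template type parameters. They are received as unnamed tag arguments and re-instantiated with {} inside the body (for example WmmaK{} / AK1{}), which keeps each quantity a distinct compile-time constant type. A stripped-down sketch of that idiom with a toy integral-constant type (Num<> stands in for ck::Number<> and is not the CK implementation):

    // Toy version of the "pass compile-time integers as type tags" idiom.
    template <int N>
    struct Num
    {
        static constexpr int value = N;
    };

    template <int A, int B>
    constexpr Num<A / B> operator/(Num<A>, Num<B>) { return {}; }

    template <typename WmmaK, typename AK1>
    constexpr auto make_ak_row(const WmmaK&, const AK1&) // unnamed tag parameters
    {
        // Re-instantiate the tags with {}; the result is again a constant type.
        constexpr auto ak_row = WmmaK{} / AK1{};
        return ak_row;
    }

    static_assert(decltype(make_ak_row(Num<16>{}, Num<8>{}))::value == 2);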