Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
381a7317
Commit
381a7317
authored
Aug 04, 2023
by
letaoqin
Browse files
redefine interface
parent
6fa4feac
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
142 deletions
+63
-142
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp
...tten_bias/batched_multihead_attention_bias_forward_v2.cpp
+5
-5
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc
...ten_bias/run_batched_multihead_attention_bias_forward.inc
+14
-21
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
...n/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
+0
-64
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_bias_xdl_cshuffle_v2.hpp
...vice/impl/device_batched_mha_fwd_bias_xdl_cshuffle_v2.hpp
+44
-52
No files found.
example/52_flash_atten_bias/batched_multihead_attention_bias_forward_v2.cpp
View file @
381a7317
...
@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
...
@@ -9,7 +9,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
Gemm1
Gemm1
*/
*/
#define DIM
64
// DIM should be a multiple of 8.
#define DIM
128
// DIM should be a multiple of 8.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -53,7 +53,7 @@ using CDataType = DataType;
...
@@ -53,7 +53,7 @@ using CDataType = DataType;
using
DDataType
=
F16
;
using
DDataType
=
F16
;
using
ZDataType
=
U16
;
// INT32
using
ZDataType
=
U16
;
// INT32
using
LSEDataType
=
F32
;
using
LSEDataType
=
F32
;
using
Acc0BiasDataType
=
ck
::
Tuple
<>
;
using
Acc0BiasDataType
=
ck
::
Tuple
<
DDataType
>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
using
Acc1BiasDataType
=
ck
::
Tuple
<>
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
static
constexpr
ck
::
index_t
NumDimG
=
2
;
...
@@ -80,7 +80,7 @@ static constexpr bool Deterministic = false;
...
@@ -80,7 +80,7 @@ static constexpr bool Deterministic = false;
#if(DIM <= 32)
#if(DIM <= 32)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttention
Bias
Forward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
R2
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
@@ -151,7 +151,7 @@ using DeviceGemmInstance =
...
@@ -151,7 +151,7 @@ using DeviceGemmInstance =
Deterministic
>
;
Deterministic
>
;
#elif(DIM <= 64)
#elif(DIM <= 64)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttention
Bias
Forward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
R2
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
@@ -222,7 +222,7 @@ using DeviceGemmInstance =
...
@@ -222,7 +222,7 @@ using DeviceGemmInstance =
Deterministic
>
;
Deterministic
>
;
#elif(DIM <= 128)
#elif(DIM <= 128)
using
DeviceGemmInstance
=
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttention
Bias
Forward_Xdl_CShuffle_V2
<
ck
::
tensor_operation
::
device
::
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
R2
<
NumDimG
,
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
...
...
example/52_flash_atten_bias/run_batched_multihead_attention_bias_forward.inc
View file @
381a7317
...
@@ -137,7 +137,7 @@ int run(int argc, char* argv[])
...
@@ -137,7 +137,7 @@ int run(int argc, char* argv[])
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b0_gs_ns_ks
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
b1_gs_os_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
B1DataType
>
{
-
2
,
2
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
DDataType
>
{
-
1
,
1
});
d_gs_ms_ns
.
GenerateTensorValue
(
GeneratorTensor_2
<
DDataType
>
{
-
2
,
2
});
break
;
break
;
case
2
:
case
2
:
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
a_gs_ms_ks
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
...
@@ -188,11 +188,10 @@ int run(int argc, char* argv[])
...
@@ -188,11 +188,10 @@ int run(int argc, char* argv[])
static_cast
<
B0DataType
*>
(
b0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
nullptr
),
static_cast
<
ZDataType
*>
(
nullptr
),
static_cast
<
ZDataType
*>
(
nullptr
),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
{
},
// std::array<void*, 1> p_acc0_biases;
std
::
array
<
void
*
,
1
>
{
d_device_buf
.
GetDeviceBuffer
()
},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc1_biases;
{},
// std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
a_gs_ms_ks_strides
,
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_lengths
,
...
@@ -201,13 +200,11 @@ int run(int argc, char* argv[])
...
@@ -201,13 +200,11 @@ int run(int argc, char* argv[])
b1_gs_os_ns_strides
,
b1_gs_os_ns_strides
,
c_gs_ms_os_lengths
,
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
,
c_gs_ms_os_strides
,
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
z_gs_ms_ns_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
{},
//
std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_gs_ms_ns_lengths
},
//
acc0_biases_gs_ms_ns_lengths
{},
//
std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides
},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_gs_ms_ns_strides
},
//
acc0_biases_gs_ms_ns_strides
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op
,
a_element_op
,
...
@@ -244,18 +241,17 @@ int run(int argc, char* argv[])
...
@@ -244,18 +241,17 @@ int run(int argc, char* argv[])
if
(
do_verification
)
if
(
do_verification
)
{
{
// run for storing z tensor
// run for storing z tensor
argument
=
gemm
.
MakeArgument
(
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B0DataType
*>
(
b0_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
B1DataType
*>
(
b1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DDataType
*>
(
d_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
z_device_buf
.
GetDeviceBuffer
()),
static_cast
<
ZDataType
*>
(
z_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
static_cast
<
LSEDataType
*>
(
lse_device_buf
.
GetDeviceBuffer
()),
{},
// std::array<void*, 1> p_acc0_biases;
std
::
array
<
void
*
,
1
>
{
{},
// std::array<void*, 1> p_acc1_biases;
d_device_buf
.
GetDeviceBuffer
()},
// std::array<void*, 1> p_acc0_biases;
{},
// std::array<void*, 1> p_acc1_biases;
a_gs_ms_ks_lengths
,
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
,
a_gs_ms_ks_strides
,
b0_gs_ns_ks_lengths
,
b0_gs_ns_ks_lengths
,
...
@@ -264,13 +260,13 @@ int run(int argc, char* argv[])
...
@@ -264,13 +260,13 @@ int run(int argc, char* argv[])
b1_gs_os_ns_strides
,
b1_gs_os_ns_strides
,
c_gs_ms_os_lengths
,
c_gs_ms_os_lengths
,
c_gs_ms_os_strides
,
c_gs_ms_os_strides
,
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
z_gs_ms_ns_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_lengths},
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
{},
// std::array<std::vector<ck::index_t>, 1>{acc0_biases_gs_ms_ns_strides},
d_gs_ms_ns_lengths
},
// acc0_biases_gs_ms_ns_lengths
std
::
array
<
std
::
vector
<
ck
::
index_t
>
,
1
>
{
d_gs_ms_ns_strides
},
// acc0_biases_gs_ms_ns_strides
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_lengths},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
{},
// std::array<std::vector<ck::index_t>, 1>{acc1_biases_gs_ms_os_strides},
a_element_op
,
a_element_op
,
...
@@ -326,11 +322,8 @@ int run(int argc, char* argv[])
...
@@ -326,11 +322,8 @@ int run(int argc, char* argv[])
a_g_m_k
,
b0_g_k_n
,
acc0_g_m_n
,
a_element_op
,
b0_element_op
,
acc0_element_op
);
a_g_m_k
,
b0_g_k_n
,
acc0_g_m_n
,
a_element_op
,
b0_element_op
,
acc0_element_op
);
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
ref_gemm0_invoker
.
Run
(
ref_gemm0_argument
);
// bias
//bias
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
+=
d_g_m_n
(
idx
);
});
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
self
(
idx
)
+=
d_g_m_n
(
idx
);
});
// masking
// masking
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
M
,
N
);
const
auto
mask
=
DeviceGemmInstance
::
C0MatrixMask
(
M
,
N
);
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp
View file @
381a7317
...
@@ -127,70 +127,6 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
...
@@ -127,70 +127,6 @@ struct DeviceBatchedMultiheadAttentionForward : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
template
<
index_t
NumDimG
,
index_t
NumDimM
,
index_t
NumDimN
,
index_t
NumDimK
,
index_t
NumDimO
,
typename
ADataType
,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
typename
ZDataType
,
typename
LSEDataType
,
typename
Acc0BiasDataType
,
typename
Acc1BiasDataType
,
typename
AElementwiseOperation
,
typename
B0ElementwiseOperation
,
typename
Acc0ElementwiseOperation
,
typename
B1ElementwiseOperation
,
typename
CElementwiseOperation
,
MaskingSpecialization
MaskingSpec
>
struct
DeviceBatchedMultiheadAttentionBiasForward
:
public
BaseOperator
{
static
constexpr
index_t
NumAcc0Bias
=
Acc0BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b0
,
const
void
*
p_b1
,
void
*
p_c
,
const
void
*
p_d
,
void
*
p_z
,
void
*
p_lse
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc1Bias
>
p_acc1_biases
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_lengths
,
const
std
::
vector
<
index_t
>&
a_gs_ms_ks_strides
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_lengths
,
const
std
::
vector
<
index_t
>&
b_gs_ns_ks_strides
,
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_lengths
,
// b1_gs_os_ns_lengths
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
// d_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
,
// d_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
// z_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
// z_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
// lse_gs_ms_lengths
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_lengths
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumAcc0Bias
>
acc0_biases_gs_ms_ns_strides
,
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_lengths
,
// acc1_biases_gs_ms_os_lengths
const
std
::
array
<
std
::
vector
<
index_t
>
,
NumAcc1Bias
>
acc1_biases_gs_ms_gemm1ns_strides
,
// acc1_biases_gs_ms_os_strides
AElementwiseOperation
a_element_op
,
B0ElementwiseOperation
b0_element_op
,
Acc0ElementwiseOperation
acc0_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
,
float
p_dropout
,
std
::
tuple
<
unsigned
long
long
,
unsigned
long
long
>
seeds
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_fwd_bias_xdl_cshuffle_v2.hpp
View file @
381a7317
...
@@ -124,9 +124,9 @@ __global__ void
...
@@ -124,9 +124,9 @@ __global__ void
p_b_grid
+
b_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_c_grid
+
c_batch_offset
,
nullptr
?
nullptr
:
p_d_grid
+
d_batch_offset
,
p_d_grid
==
nullptr
?
nullptr
:
p_d_grid
+
d_batch_offset
,
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_lse_grid
==
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -157,9 +157,9 @@ __global__ void
...
@@ -157,9 +157,9 @@ __global__ void
p_b_grid
+
b_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_d_grid
+
d_batch_offset
,
p_d_grid
==
nullptr
?
nullptr
:
p_d_grid
+
d_batch_offset
,
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
,
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_lse_grid
==
nullptr
?
nullptr
:
p_lse_grid
+
lse_batch_offset
,
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -288,28 +288,27 @@ template <index_t NumDimG,
...
@@ -288,28 +288,27 @@ template <index_t NumDimG,
MaskingSpecialization
MaskingSpec
,
MaskingSpecialization
MaskingSpec
,
bool
Deterministic
,
bool
Deterministic
,
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
LoopScheduler
LoopSched
=
LoopScheduler
::
Default
>
struct
DeviceBatchedMultiheadAttention
Bias
Forward_Xdl_CShuffle_V2
struct
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
R2
:
public
DeviceBatchedMultiheadAttention
Bias
Forward
<
NumDimG
,
:
public
DeviceBatchedMultiheadAttentionForward
<
NumDimG
,
NumDimM
,
NumDimM
,
NumDimN
,
NumDimN
,
NumDimK
,
NumDimK
,
NumDimO
,
NumDimO
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
B1DataType
,
B1DataType
,
CDataType
,
CDataType
,
ZDataType
,
ZDataType
,
LSEDataType
,
LSEDataType
,
Acc0BiasDataType
,
Acc0BiasDataType
,
Acc1BiasDataType
,
Acc1BiasDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
MaskingSpec
>
MaskingSpec
>
{
{
using
DDataType
=
ADataType
;
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
static_assert
(
NumDimG
>
0
&&
NumDimM
>
0
&&
NumDimN
>
0
&&
NumDimK
>
0
&&
NumDimO
>
0
,
"Number of dimension must be greater than 0"
);
"Number of dimension must be greater than 0"
);
...
@@ -317,7 +316,12 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -317,7 +316,12 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
static
constexpr
index_t
NumAcc1Bias
=
Acc1BiasDataType
::
Size
();
// TODO ANT: implement bias combination
// TODO ANT: implement bias combination
static_assert
(
NumAcc0Bias
==
0
&&
NumAcc0Bias
==
0
,
"Bias addition is unimplemented"
);
static_assert
(
NumAcc0Bias
<=
1
,
"Acc0 Bias addition is max support one bias"
);
static_assert
(
NumAcc1Bias
==
0
,
"Acc1 Bias addition is unimplemented"
);
static_assert
(
NumAcc1Bias
==
0
?
true
:
std
::
is_same_v
<
ADataType
,
ck
::
tuple_element_t
<
0
,
Acc0BiasDataType
>>
);
using
DDataType
=
ADataType
;
#if 0
#if 0
// TODO ANT: use alias
// TODO ANT: use alias
...
@@ -329,7 +333,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -329,7 +333,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
static constexpr index_t NumDimGemm1K = NumDimN;
static constexpr index_t NumDimGemm1K = NumDimN;
#endif
#endif
using
DeviceOp
=
DeviceBatchedMultiheadAttention
Bias
Forward_Xdl_CShuffle_V2
;
using
DeviceOp
=
DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
R2
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -574,7 +578,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -574,7 +578,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const
BDataType
*
p_b_grid
,
const
BDataType
*
p_b_grid
,
const
B1DataType
*
p_b1_grid
,
const
B1DataType
*
p_b1_grid
,
CDataType
*
p_c_grid
,
CDataType
*
p_c_grid
,
const
DDataType
*
p_d_grid
,
ZDataType
*
p_z_grid
,
ZDataType
*
p_z_grid
,
LSEDataType
*
p_lse_grid
,
LSEDataType
*
p_lse_grid
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
...
@@ -587,8 +590,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -587,8 +590,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
...
@@ -609,7 +610,8 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -609,7 +610,8 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_b1_grid_
{
p_b1_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
p_d_grid_
{
p_d_grid
},
p_d_grid_
{
NumAcc0Bias
==
0
?
nullptr
:
static_cast
<
const
DDataType
*>
(
p_acc0_biases
[
0
])},
p_z_grid_
{
p_z_grid
},
p_z_grid_
{
p_z_grid
},
p_lse_grid_
{
p_lse_grid
},
p_lse_grid_
{
p_lse_grid
},
a_grid_desc_ak0_m_ak1_
{
a_grid_desc_ak0_m_ak1_
{
...
@@ -620,7 +622,10 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -620,7 +622,10 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
c_grid_desc_m_n_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_grid_desc_m_n_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
c_gs_ms_gemm1ns_strides
)},
d_grid_desc_m_n_
{
MakeZGridDescriptor_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
)},
d_grid_desc_m_n_
{
NumAcc0Bias
==
0
?
DGridDesc_M_N
{}
:
MakeZGridDescriptor_M_N
(
acc0_biases_gs_ms_ns_lengths
[
0
],
acc0_biases_gs_ms_ns_strides
[
0
])},
z_grid_desc_m_n_
{
MakeZGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
z_grid_desc_m_n_
{
MakeZGridDescriptor_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
lse_grid_desc_m_
{
DeviceOp
::
MakeLSEGridDescriptor_M
(
lse_gs_ms_lengths
[
NumDimG
])},
lse_grid_desc_m_
{
DeviceOp
::
MakeLSEGridDescriptor_M
(
lse_gs_ms_lengths
[
NumDimG
])},
a_grid_desc_g_m_k_
{
a_grid_desc_g_m_k_
{
...
@@ -631,8 +636,10 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -631,8 +636,10 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
b1_gs_gemm1ns_gemm1ks_lengths
,
b1_gs_gemm1ns_gemm1ks_strides
)},
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
c_gs_ms_gemm1ns_lengths
,
c_gs_ms_gemm1ns_strides
)},
c_gs_ms_gemm1ns_strides
)},
d_grid_desc_g_m_n_
{
d_grid_desc_g_m_n_
{
NumAcc0Bias
==
0
?
DGridDesc_G_M_N
{}
Transform
::
MakeCGridDescriptor_G_M_N
(
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
)},
:
Transform
::
MakeCGridDescriptor_G_M_N
(
acc0_biases_gs_ms_ns_lengths
[
0
],
acc0_biases_gs_ms_ns_strides
[
0
])},
z_grid_desc_g_m_n_
{
z_grid_desc_g_m_n_
{
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
Transform
::
MakeCGridDescriptor_G_M_N
(
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
...
@@ -666,10 +673,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -666,10 +673,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())}
type_convert
<
index_t
>
(
lse_grid_desc_m_
.
GetElementSpaceSize
())}
{
{
// TODO ANT: implement bias addition
// TODO ANT: implement bias addition
ignore
=
p_acc0_biases
;
ignore
=
p_acc1_biases
;
ignore
=
p_acc1_biases
;
ignore
=
acc0_biases_gs_ms_ns_lengths
;
ignore
=
acc0_biases_gs_ms_ns_strides
;
ignore
=
acc1_biases_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_biases_gs_ms_gemm1ns_lengths
;
ignore
=
acc1_biases_gs_ms_gemm1ns_strides
;
ignore
=
acc1_biases_gs_ms_gemm1ns_strides
;
...
@@ -1052,7 +1056,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -1052,7 +1056,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const
BDataType
*
p_b
,
const
BDataType
*
p_b
,
const
B1DataType
*
p_b1
,
const
B1DataType
*
p_b1
,
CDataType
*
p_c
,
CDataType
*
p_c
,
const
DDataType
*
p_d
,
ZDataType
*
p_z
,
ZDataType
*
p_z
,
LSEDataType
*
p_lse
,
LSEDataType
*
p_lse
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
...
@@ -1065,8 +1068,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -1065,8 +1068,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
...
@@ -1088,7 +1089,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -1088,7 +1089,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
p_b
,
p_b
,
p_b1
,
p_b1
,
p_c
,
p_c
,
p_d
,
p_z
,
p_z
,
p_lse
,
p_lse
,
p_acc0_biases
,
p_acc0_biases
,
...
@@ -1101,8 +1101,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -1101,8 +1101,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
z_gs_ms_ns_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
...
@@ -1128,7 +1126,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -1128,7 +1126,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const
void
*
p_b
,
const
void
*
p_b
,
const
void
*
p_b1
,
const
void
*
p_b1
,
void
*
p_c
,
void
*
p_c
,
const
void
*
p_d
,
void
*
p_z
,
void
*
p_z
,
void
*
p_lse
,
void
*
p_lse
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
const
std
::
array
<
void
*
,
NumAcc0Bias
>
p_acc0_biases
,
...
@@ -1141,8 +1138,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -1141,8 +1138,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
d_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_lengths
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
z_gs_ms_ns_strides
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
const
std
::
vector
<
index_t
>&
lse_gs_ms_lengths
,
...
@@ -1164,7 +1159,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -1164,7 +1159,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
B1DataType
*>
(
p_b1
),
static_cast
<
const
B1DataType
*>
(
p_b1
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
const
DDataType
*>
(
p_d
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
ZDataType
*>
(
p_z
),
static_cast
<
LSEDataType
*>
(
p_lse
),
static_cast
<
LSEDataType
*>
(
p_lse
),
p_acc0_biases
,
// cast in struct Argument
p_acc0_biases
,
// cast in struct Argument
...
@@ -1177,8 +1171,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -1177,8 +1171,6 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
b1_gs_gemm1ns_gemm1ks_strides
,
// b1_gs_os_ns_strides
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_lengths
,
// c_gs_ms_os_lengths
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
c_gs_ms_gemm1ns_strides
,
// c_gs_ms_os_strides
d_gs_ms_ns_lengths
,
d_gs_ms_ns_strides
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_lengths
,
z_gs_ms_ns_strides
,
z_gs_ms_ns_strides
,
lse_gs_ms_lengths
,
lse_gs_ms_lengths
,
...
@@ -1207,7 +1199,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
...
@@ -1207,7 +1199,7 @@ struct DeviceBatchedMultiheadAttentionBiasForward_Xdl_CShuffle_V2
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
// clang-format off
// clang-format off
str
<<
"DeviceBatchedMultiheadAttention
Bias
Forward_Xdl_CShuffle_V2"
str
<<
"DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle_V2
R2
"
<<
"<"
<<
"<"
<<
BlockSize
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
MPerBlock
<<
", "
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment