Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
6a77a6da
Commit
6a77a6da
authored
Jul 08, 2022
by
Tri Dao
Browse files
Refactor gemm_cl to template on either __half or __nv_bfloat16
parent
e518a4b3
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
49 additions
and
32 deletions
+49
-32
csrc/flash_attn/src/fmha/gemm.h
csrc/flash_attn/src/fmha/gemm.h
+19
-2
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
+12
-12
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
+1
-1
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
+12
-12
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
+5
-5
No files found.
csrc/flash_attn/src/fmha/gemm.h
View file @
6a77a6da
...
@@ -253,9 +253,26 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
...
@@ -253,9 +253,26 @@ inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N])
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////
/// Statically maps half types => cutlass data types
/////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Type_
>
struct
HalfTypeToCutlassType
{
using
Type
=
Type_
;
};
/// Statically maps __half => cutlass::half_t
template
<
>
struct
HalfTypeToCutlassType
<
__half
>
{
using
Type
=
cutlass
::
half_t
;
};
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
template
<
>
struct
HalfTypeToCutlassType
<
__nv_bfloat16
>
{
using
Type
=
cutlass
::
bfloat16_t
;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Acc
,
typename
A
,
typename
B
,
int
M
,
int
N
>
template
<
typename
elem_type
,
typename
Acc
,
typename
A
,
typename
B
,
int
M
,
int
N
>
inline
__device__
void
gemm_cl
(
Acc
(
&
acc
)[
M
][
N
],
const
A
(
&
a
)[
M
],
const
B
(
&
b
)[
N
])
{
inline
__device__
void
gemm_cl
(
Acc
(
&
acc
)[
M
][
N
],
const
A
(
&
a
)[
M
],
const
B
(
&
b
)[
N
])
{
using
Shape
=
cutlass
::
gemm
::
GemmShape
<
16
*
M
,
16
*
N
,
16
>
;
using
Shape
=
cutlass
::
gemm
::
GemmShape
<
16
*
M
,
16
*
N
,
16
>
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
...
@@ -267,7 +284,7 @@ inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N
...
@@ -267,7 +284,7 @@ inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N
// TD [2022-06-02] We don't support Volta (SM70) yet.
// TD [2022-06-02] We don't support Volta (SM70) yet.
assert
(
0
);
assert
(
0
);
#endif
#endif
using
Element
=
cutlass
::
half_t
;
using
Element
=
typename
HalfTypeToCutlassType
<
elem_type
>::
Type
;
using
ElementC
=
float
;
using
ElementC
=
float
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
...
...
csrc/flash_attn/src/fmha_block_dgrad_kernel_1xN_loop.h
View file @
6a77a6da
...
@@ -407,9 +407,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
...
@@ -407,9 +407,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
smem_do
.
load
(
frag_do
[
ki
&
1
],
ki
);
smem_do
.
load
(
frag_do
[
ki
&
1
],
ki
);
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
smem_v
.
load
(
frag_v
[
ki
&
1
],
ki
);
smem_v
.
load
(
frag_v
[
ki
&
1
],
ki
);
fmha
::
gemm_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
}
else
{
}
else
{
fmha
::
gemm_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[
ki
-
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[
ki
-
1
]);
}
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
...
@@ -423,9 +423,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
...
@@ -423,9 +423,9 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
{
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
int
ki
=
Mma_tile_p
::
MMAS_K
;
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
fmha
::
gemm_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
}
else
{
}
else
{
fmha
::
gemm_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)]);
fmha
::
gemm_cl
<
__half
>
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)]);
}
}
}
}
...
@@ -514,14 +514,14 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
...
@@ -514,14 +514,14 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
// Trigger the load from shared memory for the next series of Q values.
// Trigger the load from shared memory for the next series of Q values.
smem_kt
.
load
(
frag_kt
[
ki
&
1
],
ki
);
smem_kt
.
load
(
frag_kt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
// Do the math for the values already in registers.
fmha
::
gemm_cl
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
// fmha::gemm_cl
<__half>
(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
}
// Do the final stage of math.
// Do the final stage of math.
{
{
int
ki
=
Mma_tile_dq
::
MMAS_K
;
int
ki
=
Mma_tile_dq
::
MMAS_K
;
fmha
::
gemm_cl
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
// fmha::gemm_cl
<__half>
(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
}
static_assert
(
Gmem_tile_dq
::
LOOPS
==
1
);
static_assert
(
Gmem_tile_dq
::
LOOPS
==
1
);
...
@@ -554,13 +554,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
...
@@ -554,13 +554,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
// Trigger the load from shared memory for the next series of Q values.
// Trigger the load from shared memory for the next series of Q values.
smem_dot
.
load
(
frag_dot
[
ki
&
1
],
ki
);
smem_dot
.
load
(
frag_dot
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
// Do the math for the values already in registers.
fmha
::
gemm_cl
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
}
}
// Do the final stage of math.
// Do the final stage of math.
{
{
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
fmha
::
gemm_cl
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
}
}
// __syncthreads();
// __syncthreads();
...
@@ -612,13 +612,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
...
@@ -612,13 +612,13 @@ inline __device__ void compute_block_dq_dk_dv_1xN_one_iter(const Params ¶ms,
// Trigger the load from shared memory for the next series of Q values.
// Trigger the load from shared memory for the next series of Q values.
smem_qt
.
load
(
frag_qt
[
ki
&
1
],
ki
);
smem_qt
.
load
(
frag_qt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
// Do the math for the values already in registers.
fmha
::
gemm_cl
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
}
// Do the final stage of math.
// Do the final stage of math.
{
{
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
fmha
::
gemm_cl
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
}
// Make sure dQ is in shared memory.
// Make sure dQ is in shared memory.
...
...
csrc/flash_attn/src/fmha_block_fprop_kernel_1xN.h
View file @
6a77a6da
...
@@ -365,7 +365,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c
...
@@ -365,7 +365,7 @@ inline __device__ void device_block_1xN_(const Params ¶ms, const int bidb, c
// Do this part of O = P^T * V^T.
// Do this part of O = P^T * V^T.
#pragma unroll
#pragma unroll
for
(
int
ki
=
0
;
ki
<
Mma_tile_o
::
MMAS_K
;
++
ki
)
{
for
(
int
ki
=
0
;
ki
<
Mma_tile_o
::
MMAS_K
;
++
ki
)
{
fmha
::
gemm_cl
(
acc_o
,
frag_p
[
ki
],
frag_v
[
ki
]);
fmha
::
gemm_cl
<
__half
>
(
acc_o
,
frag_p
[
ki
],
frag_v
[
ki
]);
}
}
// The mapping from tidx to rows changes between the softmax and the O-reduction.
// The mapping from tidx to rows changes between the softmax and the O-reduction.
...
...
csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
View file @
6a77a6da
...
@@ -369,9 +369,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
...
@@ -369,9 +369,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
smem_do
.
load
(
frag_do
[
ki
&
1
],
ki
);
smem_do
.
load
(
frag_do
[
ki
&
1
],
ki
);
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
smem_v
.
load
(
frag_v
[
ki
&
1
],
ki
);
smem_v
.
load
(
frag_v
[
ki
&
1
],
ki
);
fmha
::
gemm_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
}
else
{
}
else
{
fmha
::
gemm_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[
ki
-
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[
ki
-
1
]);
}
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
...
@@ -385,9 +385,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
...
@@ -385,9 +385,9 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
{
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
int
ki
=
Mma_tile_p
::
MMAS_K
;
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
if
(
!
Kernel_traits
::
V_IN_REGS
)
{
fmha
::
gemm_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)
&
1
]);
}
else
{
}
else
{
fmha
::
gemm_cl
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)]);
fmha
::
gemm_cl
<
__half
>
(
acc_dp
,
frag_do
[(
ki
-
1
)
&
1
],
frag_v
[(
ki
-
1
)]);
}
}
}
}
...
@@ -442,14 +442,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
...
@@ -442,14 +442,14 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
// Trigger the load from shared memory for the next series of Q values.
// Trigger the load from shared memory for the next series of Q values.
smem_kt
.
load
(
frag_kt
[
ki
&
1
],
ki
);
smem_kt
.
load
(
frag_kt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
// Do the math for the values already in registers.
fmha
::
gemm_cl
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
// fmha::gemm_cl
<__half>
(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
}
// Do the final stage of math.
// Do the final stage of math.
{
{
int
ki
=
Mma_tile_dq
::
MMAS_K
;
int
ki
=
Mma_tile_dq
::
MMAS_K
;
fmha
::
gemm_cl
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dq
,
frag_p
[
ki
-
1
],
frag_kt
[(
ki
-
1
)
&
1
]);
// fmha::gemm_cl(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
// fmha::gemm_cl
<__half>
(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
}
}
static_assert
(
Gmem_tile_dq
::
LOOPS
==
1
);
static_assert
(
Gmem_tile_dq
::
LOOPS
==
1
);
...
@@ -485,13 +485,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
...
@@ -485,13 +485,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
// Trigger the load from shared memory for the next series of Q values.
// Trigger the load from shared memory for the next series of Q values.
smem_dot
.
load
(
frag_dot
[
ki
&
1
],
ki
);
smem_dot
.
load
(
frag_dot
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
// Do the math for the values already in registers.
fmha
::
gemm_cl
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
}
}
// Do the final stage of math.
// Do the final stage of math.
{
{
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
fmha
::
gemm_cl
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dv
,
frag_s
[(
ki
-
1
)],
frag_dot
[(
ki
-
1
)
&
1
]);
}
}
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
...
@@ -542,13 +542,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
...
@@ -542,13 +542,13 @@ inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng
// Trigger the load from shared memory for the next series of Q values.
// Trigger the load from shared memory for the next series of Q values.
smem_qt
.
load
(
frag_qt
[
ki
&
1
],
ki
);
smem_qt
.
load
(
frag_qt
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
// Do the math for the values already in registers.
fmha
::
gemm_cl
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
}
// Do the final stage of math.
// Do the final stage of math.
{
{
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
int
ki
=
Mma_tile_dkv
::
MMAS_K
;
fmha
::
gemm_cl
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_dk
,
frag_dpt
[(
ki
-
1
)],
frag_qt
[(
ki
-
1
)
&
1
]);
}
}
// Make sure dQ is in shared memory.
// Make sure dQ is in shared memory.
...
...
csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
View file @
6a77a6da
...
@@ -115,12 +115,12 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
...
@@ -115,12 +115,12 @@ struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
// Trigger the load from shared memory for the next series of Q values.
// Trigger the load from shared memory for the next series of Q values.
Base
::
smem_q
.
load
(
Base
::
frag_q
[
ki
&
1
],
ki
);
Base
::
smem_q
.
load
(
Base
::
frag_q
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
// Do the math for the values already in registers.
fmha
::
gemm_cl
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
fmha
::
gemm_cl
<
__half
>
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
}
}
// Do the final stage of math.
// Do the final stage of math.
{
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
int
ki
=
Mma_tile_p
::
MMAS_K
;
fmha
::
gemm_cl
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
fmha
::
gemm_cl
<
__half
>
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)]);
}
}
}
}
...
@@ -175,12 +175,12 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
...
@@ -175,12 +175,12 @@ struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
Base
::
smem_q
.
load
(
Base
::
frag_q
[
ki
&
1
],
ki
);
Base
::
smem_q
.
load
(
Base
::
frag_q
[
ki
&
1
],
ki
);
Base
::
smem_k
.
load
(
frag_k
[
ki
&
1
],
ki
);
Base
::
smem_k
.
load
(
frag_k
[
ki
&
1
],
ki
);
// Do the math for the values already in registers.
// Do the math for the values already in registers.
fmha
::
gemm_cl
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
}
// Do the final stage of math.
// Do the final stage of math.
{
{
int
ki
=
Mma_tile_p
::
MMAS_K
;
int
ki
=
Mma_tile_p
::
MMAS_K
;
fmha
::
gemm_cl
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
fmha
::
gemm_cl
<
__half
>
(
acc_p
,
Base
::
frag_q
[(
ki
-
1
)
&
1
],
frag_k
[(
ki
-
1
)
&
1
]);
}
}
}
}
...
@@ -494,7 +494,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
...
@@ -494,7 +494,7 @@ inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const i
// Do this part of O = P^T * V^T.
// Do this part of O = P^T * V^T.
#pragma unroll
#pragma unroll
for
(
int
ki
=
0
;
ki
<
Mma_tile_o
::
MMAS_K
;
++
ki
)
{
for
(
int
ki
=
0
;
ki
<
Mma_tile_o
::
MMAS_K
;
++
ki
)
{
fmha
::
gemm_cl
(
acc_o
,
frag_p
[
ki
],
frag_v
[
ki
]);
fmha
::
gemm_cl
<
__half
>
(
acc_o
,
frag_p
[
ki
],
frag_v
[
ki
]);
// if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
// float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
// float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
// float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
// float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment