OpenDAS / ollama / Commits / 943464cc

Unverified commit 943464cc, authored Apr 16, 2025 by Jeffrey Morgan, committed via GitHub on Apr 16, 2025

llama: update to commit 71e90e88 (#10192)

parent 369de832
Changes: 157 — showing 20 changed files with 1077 additions and 447 deletions (+1077 -447)
ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh   +60  -34
ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu   +37  -40
ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu   +48  -39
ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh   +40  -47
ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh   +52  -49
ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu   +100 -114
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu            +20  -12
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu        +98  -18
ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh             +2   -0
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu              +3   -3
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh             +46  -30
ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu              +1   -1
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu             +144 -59
ml/backend/ggml/ggml/src/ggml-cuda/norm.cu             +116 -0
ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh            +2   -0
ml/backend/ggml/ggml/src/ggml-cuda/pad.cu              +1   -1
ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu         +148 -0
ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cuh        +3   -0
ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu         +153 -0
ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cuh        +3   -0
ml/backend/ggml/ggml/src/ggml-cuda/fattn-mma-f16.cuh (view file @ 943464cc)

@@ -406,6 +406,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #endif // CP_ASYNC_AVAILABLE
 #else
     GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(mask_h2);
     GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope);
     GGML_UNUSED(logit_softcap); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); GGML_UNUSED(tile_V);
+    GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max);
+    GGML_UNUSED(KQ_rowsum); GGML_UNUSED(kb0);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -797,6 +806,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         __syncthreads();
     }
 #else
+    GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(mask_h2);
+    GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope);
+    GGML_UNUSED(logit_softcap); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
+    GGML_UNUSED(stride_Q1); GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask);
+    GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
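The two hunks above (and several below) extend the fallback branch that is compiled when the MMA path is unavailable: every kernel argument is marked as used so the stub body builds without unused-parameter warnings before bailing out. A minimal sketch of that same pattern, using hypothetical macro names (MY_UNUSED, NO_DEVICE_CODE_STUB) rather than the ggml definitions:

#include <cstdio>

// Hypothetical stand-ins for GGML_UNUSED / NO_DEVICE_CODE; the real macros live in ggml's common headers.
#define MY_UNUSED(x) (void)(x)
#define NO_DEVICE_CODE_STUB do { } while (0)

template <int D>
__global__ void attn_stub(const float * Q, const float * K, float scale) {
#ifdef FEATURE_AVAILABLE            // assumed feature gate, e.g. tensor-core support
    // ... real kernel body would go here ...
#else
    // Mark every argument of the disabled variant as used, then do nothing,
    // so the translation unit still compiles and links cleanly.
    MY_UNUSED(Q); MY_UNUSED(K); MY_UNUSED(scale);
    NO_DEVICE_CODE_STUB;
#endif
}

int main() {
    attn_stub<64><<<1, 32>>>(nullptr, nullptr, 1.0f);
    cudaDeviceSynchronize();
    printf("launched stub kernel\n");
    return 0;
}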
@@ -931,6 +946,16 @@ static __global__ void flash_attn_ext_f16(
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias);
+    GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
@@ -970,7 +995,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
         fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
     }
 
-    launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
+    launch_fattn<D, ncols1, ncols2, KQ_per_iter>
+        (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
 }
@@ -984,38 +1010,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
     extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
 
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  8);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64);
-DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64);
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64,  8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80,  8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96,  8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112,  8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128,  8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256,  8)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64)
+DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64)
 // Kernels with ncols == 128 are only 4% faster due to register pressure.
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
-// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128);
-// Needs too much shared memory.
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
+// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
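The block above relies on a fan-out macro: one DECL_…_ALL_NCOLS2(D, ncols) invocation expands to several extern declarations of explicit template instantiations, so each (head size, column count) combination can be compiled in its own translation unit. A minimal sketch of that declaration/definition split with hypothetical names (attn_case, DECL_KERNEL_CASE*), not the ggml macros themselves:

// Sketch of the extern-template declaration pattern; names are hypothetical.
#include <cstdio>

template <int D, int ncols>
void attn_case() { std::printf("attn_case<%d,%d>\n", D, ncols); }

// One macro spells out the explicit instantiation for a single (D, ncols) pair ...
#define DECL_KERNEL_CASE(D, ncols) template void attn_case<D, ncols>()

// ... and a second macro fans out over the ncols variants, declaring them `extern`
// so other translation units can call them without re-instantiating the template.
#define DECL_KERNEL_CASE_ALL(D, ncols)     \
    extern DECL_KERNEL_CASE(D, (ncols)/4); \
    extern DECL_KERNEL_CASE(D, (ncols)/8)

DECL_KERNEL_CASE_ALL( 64, 32);
DECL_KERNEL_CASE_ALL(128, 32);

// Exactly one translation unit per case provides the matching definition:
DECL_KERNEL_CASE(64, 8);
DECL_KERNEL_CASE(128, 8);

int main() {
    attn_case<64, 8>();
    return 0;
}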
ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f16.cu (view file @ 943464cc)

@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half   * maskh = (const half   *)  mask + ne11*ic0;
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
         half kqmax_new[ncols/nwarps];
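The loop rewrite above is the core of this sync: instead of baking the number of blocks per output column into the template (parallel_blocks) and deriving an index ip from blockIdx.x, each block now walks the KV dimension with a grid-stride loop over blockIdx.y/gridDim.y, so the degree of KV parallelism is chosen at launch time. A minimal, self-contained sketch of that scheduling pattern (hypothetical kernel and buffers, not the ggml code):

#include <cstdio>

// Each block in gridDim.y processes every gridDim.y-th tile of the KV sequence,
// so the KV split is a runtime launch parameter instead of a template argument.
__global__ void kv_tile_sums(const float * kv, float * partial, int n_tiles, int tile_size) {
    float acc = 0.0f;
    for (int tile = blockIdx.y; tile < n_tiles; tile += gridDim.y) {
        const int i = tile * tile_size + threadIdx.x;  // this thread's element of the tile
        acc += kv[i];
    }
    // Thread 0 records its own running sum as a stand-in for this block's partial result;
    // a later pass would combine one entry per (column block, KV block).
    if (threadIdx.x == 0) {
        partial[blockIdx.x * gridDim.y + blockIdx.y] = acc;
    }
}

int main() {
    const int n_tiles = 8, tile_size = 32, parallel = 4;
    float *kv, *partial;
    cudaMallocManaged(&kv, n_tiles * tile_size * sizeof(float));
    cudaMallocManaged(&partial, parallel * sizeof(float));
    for (int i = 0; i < n_tiles * tile_size; ++i) kv[i] = 1.0f;

    kv_tile_sums<<<dim3(1, parallel), tile_size>>>(kv, partial, n_tiles, tile_size);
    cudaDeviceSynchronize();
    for (int b = 0; b < parallel; ++b) printf("block %d partial: %.0f\n", b, partial[b]);
    return 0;
}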
@@ -271,24 +269,36 @@ static __global__ void flash_attn_tile_ext_f16(
                 const int i0 = i00 + 2*threadIdx.x;
 
                 half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
-                if (parallel_blocks == 1) {
+                if (gridDim.y == 1) {
                     dst_val /= __half2half2(kqsum_j);
                 }
-                const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-                dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
-                dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
+                const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+                dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
+                dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
             }
 
-            if (parallel_blocks != 1 && threadIdx.x == 0) {
-                dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            if (gridDim.y != 1 && threadIdx.x == 0) {
+                dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z)*gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
             }
         }
 #else
-    NO_DEVICE_CODE;
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias);
+    GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
-template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+template <int cols_per_block, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
@@ -296,15 +306,17 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    D             = 64;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         case 128: {
             constexpr int    D             = 128;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -324,37 +336,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
-    if (Q->ne[1] <= 16) {
-        constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     }
 }
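The dispatcher above keeps a pattern that recurs throughout these files: a runtime value (logit_softcap) is pinned to a constexpr bool inside each branch, so the kernel template is instantiated once per variant and the unused branch compiles out of the device code. A small sketch of that dispatch idiom with a hypothetical kernel (not the ggml launcher):

#include <cstdio>

// The template parameter lets the compiler drop the softcap math entirely
// in the variant that does not need it.
template <bool use_softcap>
__global__ void score_kernel(float * x, int n, float softcap) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;
    if (use_softcap) {
        x[i] = softcap * tanhf(x[i] / softcap);
    }
}

// Runtime-to-compile-time dispatch: each branch fixes the flag as constexpr
// and instantiates exactly one kernel variant, mirroring the dispatchers above.
static void launch_score(float * x, int n, float softcap) {
    const int block = 128, grid = (n + block - 1) / block;
    if (softcap == 0.0f) {
        constexpr bool use_softcap = false;
        score_kernel<use_softcap><<<grid, block>>>(x, n, softcap);
    } else {
        constexpr bool use_softcap = true;
        score_kernel<use_softcap><<<grid, block>>>(x, n, softcap);
    }
}

int main() {
    const int n = 4;
    float * x;
    cudaMallocManaged(&x, n * sizeof(float));
    for (int i = 0; i < n; ++i) x[i] = float(i);
    launch_score(x, n, 30.0f);
    cudaDeviceSynchronize();
    for (int i = 0; i < n; ++i) printf("%.3f\n", x[i]);
    return 0;
}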
ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile-f32.cu (view file @ 943464cc)

@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
-template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
+template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -52,24 +52,35 @@ static __global__ void flash_attn_tile_ext_f32(
         return;
 #endif // FP16_MMA_AVAILABLE
 
     if (use_logit_softcap && !(D == 128 || D == 256)) {
+        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+        GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias);
+        GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+        GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+        GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+        GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+        GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
         NO_DEVICE_CODE;
         return;
     }
 
     // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
 
-    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.y              + nb01*ic0);
-    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.y / gqa_ratio));
-    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
+    const float2 * Q_f2  = (const float2 *) (Q    + nb02* blockIdx.z              + nb01*ic0);
+    const half2  * K_h2  = (const half2  *) (K    + nb12*(blockIdx.z / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V    + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
     const half   * maskh = (const half   *)  mask + ne11*ic0;
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -103,8 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
     __syncthreads();
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
         // Calculate KQ tile and keep track of new maximum KQ values:
         float kqmax_new[ncols/nwarps];
@@ -269,25 +279,37 @@ static __global__ void flash_attn_tile_ext_f32(
                 const int i0 = i00 + 2*threadIdx.x;
 
                 float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
-                if (parallel_blocks == 1) {
+                if (gridDim.y == 1) {
                     dst_val.x /= kqsum_j;
                     dst_val.y /= kqsum_j;
                 }
-                const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-                dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
-                dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
+                const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+                dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
+                dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
             }
 
-            if (parallel_blocks != 1 && threadIdx.x == 0) {
-                dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
+            if (gridDim.y != 1 && threadIdx.x == 0) {
+                dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z)*gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
             }
         }
 #else
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias);
+    GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
+template <int cols_per_block, bool use_logit_softcap>
 void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
@@ -295,15 +317,17 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int    D             = 64;
             constexpr int    nwarps        = 8;
             constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         case 128: {
            constexpr int    D             = 128;
            constexpr int    nwarps        = 8;
            constexpr size_t nbytes_shared = 0;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
-            launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
+            launch_fattn<D, cols_per_block, 1, -1>
+                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
         } break;
         default: {
             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
@@ -320,37 +344,22 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
-    if (Q->ne[1] <= 16) {
-        constexpr int cols_per_block = 16;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
-        constexpr int parallel_blocks = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+            launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
        }
        return;
    }
 
     constexpr int cols_per_block = 32;
-    constexpr int parallel_blocks = 1;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
     }
 }
ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f16.cuh (view file @ 943464cc)

 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -55,17 +55,16 @@ static __global__ void flash_attn_vec_ext_f16(
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
     constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y              + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio);
+    Q += nb02* blockIdx.z              + nb01*ic0;
+    K += nb12*(blockIdx.z / gqa_ratio);
+    V += nb22*(blockIdx.z / gqa_ratio);
 
     const half * maskh = (const half *) mask + ne11*ic0;
 
-    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -172,8 +171,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
@@ -283,29 +281,41 @@ static __global__ void flash_attn_vec_ext_f16(
         kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
 
         half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
-        if (parallel_blocks == 1) {
+        if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
-    NO_DEVICE_CODE;
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias);
+    GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
@@ -325,65 +335,48 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block  = 2;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 2;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block  = 4;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block  = 8;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    constexpr int cols_per_block  = 8;
-    constexpr int parallel_blocks = 1;
+    constexpr int cols_per_block = 8;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     }
 }
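Throughout these kernels, when more than one block works on the same output column (gridDim.y != 1), each block writes an unnormalized partial accumulator plus a float2 of its running softmax maximum and row sum into dst_meta; a later pass merges the partials with the usual log-sum-exp rescaling. A minimal host-side sketch of that merge rule (hypothetical combine() helper, not the ggml combine kernel):

#include <cmath>
#include <cstdio>
#include <vector>

// Each partial result carries an unnormalized output value together with the
// running softmax maximum (m) and row sum (s) of the KV chunk it covered.
struct Partial {
    float value; // sum_k exp(score_k - m) * v_k over this chunk
    float m;     // max score seen in this chunk
    float s;     // sum_k exp(score_k - m) over this chunk
};

// Merge partials with log-sum-exp rescaling, then normalize once at the end --
// the algebra that the stored (max, sum) float2 entries make possible.
float combine(const std::vector<Partial> & parts) {
    float m = -INFINITY, s = 0.0f, value = 0.0f;
    for (const Partial & p : parts) {
        const float m_new     = std::fmax(m, p.m);
        const float scale_old = std::exp(m - m_new);
        const float scale_new = std::exp(p.m - m_new);
        value = value * scale_old + p.value * scale_new;
        s     = s     * scale_old + p.s     * scale_new;
        m     = m_new;
    }
    return value / s;
}

int main() {
    // Two chunks of the same output row, produced independently.
    std::vector<Partial> parts = {{1.5f, 2.0f, 1.2f}, {0.7f, 3.0f, 0.9f}};
    printf("combined output: %f\n", combine(parts));
    return 0;
}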
ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec-f32.cuh (view file @ 943464cc)

 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
+template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
@@ -45,6 +45,18 @@ static __global__ void flash_attn_vec_ext_f32(
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
+        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+        GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias);
+        GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+        GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+        GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+        GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+        GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
         NO_DEVICE_CODE;
         return;
     }
@@ -55,16 +67,15 @@ static __global__ void flash_attn_vec_ext_f32(
     constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
     constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
 
-    const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
-    const int ip  =  blockIdx.x % parallel_blocks;          // Index in group of blocks running for the same column in parallel.
+    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
 
     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-    Q += nb02* blockIdx.y              + nb01*ic0;
-    K += nb12*(blockIdx.y / gqa_ratio);
-    V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
+    Q += nb02* blockIdx.z              + nb01*ic0;
+    K += nb12*(blockIdx.z / gqa_ratio);
+    V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
 
     const half * maskh = (const half *) mask + ne11*ic0;
 
-    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -115,7 +126,7 @@ static __global__ void flash_attn_vec_ext_f32(
         // Set memory to zero if out of bounds:
         if (ncols > 2 && ic0 + j >= ne01) {
 #pragma unroll
-            for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
                 tmp_q_i32[i] = 0;
@@ -128,7 +139,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
         const float * Q_f = (const float *) (Q + j*nb01);
 #pragma unroll
-        for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
             quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
         }
     }
@@ -141,7 +152,7 @@ static __global__ void flash_attn_vec_ext_f32(
         float2 * tmp_q_ds  = (float2 *) (tmp_q_i32 + D/sizeof(int));
 
 #pragma unroll
-        for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
 
            Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
@@ -167,8 +178,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
-    const int k_start = parallel_blocks == 1 ? 0 : ip*D;
-    for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new_arr[ncols];
@@ -268,29 +278,39 @@ static __global__ void flash_attn_vec_ext_f32(
         kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
 
         float dst_val = VKQ[j_VKQ];
-        if (parallel_blocks == 1) {
+        if (gridDim.y == 1) {
             dst_val /= kqsum[j_VKQ];
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
-        dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
+        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
+        dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
-        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
+    if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
+        dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); GGML_UNUSED(max_bias);
+    GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
+    GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
+    GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
+    GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
+    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
+template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     constexpr size_t nbytes_shared = 0;
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
+    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
@@ -307,65 +327,48 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (Q->ne[1] == 1) {
-        constexpr int cols_per_block  = 1;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block  = 2;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 2;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         }
         return;
     }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block  = 4;
-        constexpr int parallel_blocks = 4;
+        constexpr int cols_per_block = 4;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
         } else {
             constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
        }
        return;
    }
 
-    if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block  = 8;
-        constexpr int parallel_blocks = 4;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        } else {
-            constexpr bool use_logit_softcap = true;
-            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
-        }
-        return;
-    }
-
-    constexpr int cols_per_block  = 8;
-    constexpr int parallel_blocks = 1;
+    constexpr int cols_per_block = 8;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     } else {
         constexpr bool use_logit_softcap = true;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
     }
 }
ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu (view file @ 943464cc)

@@ -7,14 +7,19 @@
 #include "fattn-wmma-f16.cuh"
 
 #ifdef FP16_MMA_AVAILABLE
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+namespace wmma = nvcuda::wmma;
+#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #endif // FP16_MMA_AVAILABLE
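The added preprocessor block selects a fragment library at compile time and hides the choice behind a single wmma namespace alias, which is what lets the kernel body below drop the explicit nvcuda::wmma:: qualifiers. A reduced sketch of the same aliasing technique follows; the guard macros are the ones visible in the hunk, but the probe kernel is hypothetical, and compiling it needs a tensor-core-capable target architecture:

#include <cstdio>
#include <cuda_fp16.h>

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#include <mma.h>
namespace wmma = nvcuda::wmma;                 // CUDA tensor-core fragments
#elif defined(GGML_HIP_ROCWMMA_FATTN)
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;                      // rocWMMA exposes the same fragment vocabulary
#endif

// After the alias, one spelling works for either backend.
__global__ void mma_probe(const half * a, const half * b, float * c) {
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> fa;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> fb;
    wmma::fragment<wmma::accumulator, 16, 16, 16, float> fc;
    wmma::fill_fragment(fc, 0.0f);
    wmma::load_matrix_sync(fa, a, 16);
    wmma::load_matrix_sync(fb, b, 16);
    wmma::mma_sync(fc, fa, fb, fc);
    wmma::store_matrix_sync(c, fc, 16, wmma::mem_col_major);
}

int main() {
    printf("compiled against the selected wmma backend\n");
    return 0;
}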
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template
<
int
D
,
int
ncols
,
int
nwarps
,
int
VKQ_stride
,
int
parallel_blocks
,
typename
KQ_acc_t
,
bool
use_logit_softcap
>
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__
(
nwarps
*
WARP_SIZE
,
1
)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
template
<
int
D
,
int
ncols
,
int
nwarps
,
int
VKQ_stride
,
typename
KQ_acc_t
,
bool
use_logit_softcap
>
__launch_bounds__
(
nwarps
*
ggml_cuda_get_physical_warp_size
(),
1
)
static
__global__
void
flash_attn_ext_f16
(
const
char
*
__restrict__
Q
,
const
char
*
__restrict__
K
,
...
...
@@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
const
int
ne1
,
const
int
ne2
,
const
int
ne3
)
{
#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if defined(FLASH_ATTN_AVAILABLE) &&
(
__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
|| (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
// Skip unused kernel variants for faster compilation:
if
(
use_logit_softcap
&&
!
(
D
==
128
||
D
==
256
))
{
NO_DEVICE_CODE
;
...
...
@@ -60,19 +65,20 @@ static __global__ void flash_attn_ext_f16(
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const
int
ic0
=
ncols
*
(
blockIdx
.
x
/
parallel_blocks
);
// Index of the first Q/QKV column to work on.
const
int
ip
=
blockIdx
.
x
%
parallel_blocks
;
// Index in group of blocks running for the same column in parallel.
constexpr
int
warp_size
=
ggml_cuda_get_physical_warp_size
();
const
int
ic0
=
ncols
*
blockIdx
.
x
;
// Index of the first Q/QKV column to work on.
static_assert
(
D
<=
FATTN_KQ_STRIDE
,
"D must be <= FATTN_KQ_STRIDE."
);
static_assert
(
ncols
==
8
||
ncols
%
16
==
0
,
"ncols must be 8 or a multiple of 16."
);
constexpr
int
frag_m
=
ncols
==
8
?
32
:
16
;
constexpr
int
frag_n
=
ncols
==
8
?
8
:
16
;
static_assert
(
D
%
frag_m
==
0
,
"If ncols == 8 then D % frag_m must be 0."
);
typedef
nvcuda
::
wmma
::
fragment
<
nvcuda
::
wmma
::
matrix_a
,
frag_m
,
frag_n
,
16
,
half
,
nvcuda
::
wmma
::
row_major
>
frag_a_K
;
typedef
nvcuda
::
wmma
::
fragment
<
nvcuda
::
wmma
::
matrix_a
,
frag_m
,
frag_n
,
16
,
half
,
nvcuda
::
wmma
::
col_major
>
frag_a_V
;
typedef
nvcuda
::
wmma
::
fragment
<
nvcuda
::
wmma
::
matrix_b
,
frag_m
,
frag_n
,
16
,
half
,
nvcuda
::
wmma
::
col_major
>
frag_b
;
typedef
nvcuda
::
wmma
::
fragment
<
nvcuda
::
wmma
::
accumulator
,
frag_m
,
frag_n
,
16
,
KQ_acc_t
>
frag_c_KQ
;
typedef
nvcuda
::
wmma
::
fragment
<
nvcuda
::
wmma
::
accumulator
,
frag_m
,
frag_n
,
16
,
half
>
frag_c_VKQ
;
typedef
wmma
::
fragment
<
wmma
::
matrix_a
,
frag_m
,
frag_n
,
16
,
half
,
wmma
::
row_major
>
frag_a_K
;
typedef
wmma
::
fragment
<
wmma
::
matrix_a
,
frag_m
,
frag_n
,
16
,
half
,
wmma
::
col_major
>
frag_a_V
;
typedef
wmma
::
fragment
<
wmma
::
matrix_b
,
frag_m
,
frag_n
,
16
,
half
,
wmma
::
col_major
>
frag_b
;
typedef
wmma
::
fragment
<
wmma
::
accumulator
,
frag_m
,
frag_n
,
16
,
KQ_acc_t
>
frag_c_KQ
;
typedef
wmma
::
fragment
<
wmma
::
accumulator
,
frag_m
,
frag_n
,
16
,
half
>
frag_c_VKQ
;
constexpr
int
KQ_stride_tc
=
nwarps
*
frag_m
;
// Number of KQ rows calculated in parallel.
constexpr
int
VKQ_ratio
=
KQ_stride_tc
/
VKQ_stride
;
// Number of parallel VKQ accumulators needed to keep all warps busy.
...
...
@@ -84,16 +90,16 @@ static __global__ void flash_attn_ext_f16(
constexpr
int
kqar
=
sizeof
(
KQ_acc_t
)
/
sizeof
(
half
);
const
int
gqa_ratio
=
ne02
/
ne12
;
// With grouped query attention there are > 1 Q matrices per K, V matrix.
const
float
*
Q_f
=
(
const
float
*
)
(
Q
+
nb02
*
blockIdx
.
y
+
nb01
*
ic0
);
const
half
*
K_h
=
(
const
half
*
)
(
K
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
const
half
*
V_h
=
(
const
half
*
)
(
V
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
// K and V have same shape
const
float
*
Q_f
=
(
const
float
*
)
(
Q
+
nb02
*
blockIdx
.
z
+
nb01
*
ic0
);
const
half
*
K_h
=
(
const
half
*
)
(
K
+
nb12
*
(
blockIdx
.
z
/
gqa_ratio
));
const
half
*
V_h
=
(
const
half
*
)
(
V
+
nb12
*
(
blockIdx
.
z
/
gqa_ratio
));
// K and V have same shape
const
half
*
maskh
=
(
const
half
*
)
mask
+
(
nb31
/
sizeof
(
half
))
*
ic0
;
const
half2
*
mask2
=
(
const
half2
*
)
mask
+
(
nb31
/
sizeof
(
half
))
*
(
ic0
/
2
);
const
int
stride_Q
=
nb01
/
sizeof
(
float
);
const
int
stride_KV
=
nb11
/
sizeof
(
half
);
const
float
slopef
=
get_alibi_slope
(
max_bias
,
blockIdx
.
y
,
n_head_log2
,
m0
,
m1
);
const
float
slopef
=
get_alibi_slope
(
max_bias
,
blockIdx
.
z
,
n_head_log2
,
m0
,
m1
);
const
half
slopeh
=
__float2half
(
slopef
);
const
half2
slope2
=
make_half2
(
slopef
,
slopef
);
...
...
@@ -132,9 +138,9 @@ static __global__ void flash_attn_ext_f16(
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
const
int
j
=
j0
+
threadIdx
.
y
;
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
warp_size
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
if
(
i0
+
WARP_SIZE
>
D
/
2
&&
i
>=
D
/
2
)
{
if
(
i0
+
warp_size
>
D
/
2
&&
i
>=
D
/
2
)
{
break
;
}
VKQ2
[
j
*
(
D_padded
/
2
)
+
i
]
=
make_half2
(
0.0
f
,
0.0
f
);
...
...
@@ -146,9 +152,9 @@ static __global__ void flash_attn_ext_f16(
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
const
int
j
=
j0
+
threadIdx
.
y
;
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
;
i0
+=
WARP_SIZE
)
{
for
(
int
i0
=
0
;
i0
<
D
;
i0
+=
warp_size
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
if
(
i0
+
WARP_SIZE
>
D
&&
i
>=
D
)
{
if
(
i0
+
warp_size
>
D
&&
i
>=
D
)
{
break
;
}
KQ
[
j
*
D_padded
+
i
]
=
ic0
+
j
<
ne01
?
Q_f
[
j
*
stride_Q
+
i
]
*
scale
:
0.0
f
;
...
...
@@ -162,34 +168,34 @@ static __global__ void flash_attn_ext_f16(
for
(
int
i0
=
0
;
i0
<
D
;
i0
+=
16
)
{
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
frag_n
)
{
nvcuda
::
wmma
::
load_matrix_sync
(
Q_b
[
i0
/
16
][
j0
/
frag_n
],
KQ
+
j0
*
D_padded
+
i0
,
D_padded
);
wmma
::
load_matrix_sync
(
Q_b
[
i0
/
16
][
j0
/
frag_n
],
KQ
+
j0
*
D_padded
+
i0
,
D_padded
);
}
}
__syncthreads
();
// Iterate over ne11 == previous tokens:
for
(
int
k_VKQ_0
=
ip
*
FATTN_KQ_STRIDE
;
k_VKQ_0
<
ne11
;
k_VKQ_0
+=
parallel_blocks
*
FATTN_KQ_STRIDE
)
{
for
(
int
k_VKQ_0
=
blockIdx
.
y
*
FATTN_KQ_STRIDE
;
k_VKQ_0
<
ne11
;
k_VKQ_0
+=
gridDim
.
y
*
FATTN_KQ_STRIDE
)
{
// Calculate tile of KQ:
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
FATTN_KQ_STRIDE
;
i_KQ_0
+=
KQ_stride_tc
)
{
frag_c_KQ
KQ_c
[
ncols
/
frag_n
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
/
frag_n
;
++
j
)
{
nvcuda
::
wmma
::
fill_fragment
(
KQ_c
[
j
],
0.0
f
);
wmma
::
fill_fragment
(
KQ_c
[
j
],
static_cast
<
KQ_acc_t
>
(
0.0
f
)
)
;
}
#pragma unroll
for
(
int
k_KQ_0
=
0
;
k_KQ_0
<
D
;
k_KQ_0
+=
16
)
{
frag_a_K
K_a
;
nvcuda
::
wmma
::
load_matrix_sync
(
K_a
,
K_h
+
(
k_VKQ_0
+
i_KQ_0
+
frag_m
*
threadIdx
.
y
)
*
stride_KV
+
k_KQ_0
,
stride_KV
);
wmma
::
load_matrix_sync
(
K_a
,
K_h
+
(
k_VKQ_0
+
i_KQ_0
+
frag_m
*
threadIdx
.
y
)
*
stride_KV
+
k_KQ_0
,
stride_KV
);
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
/
frag_n
;
++
j
)
{
nvcuda
::
wmma
::
mma_sync
(
KQ_c
[
j
],
K_a
,
Q_b
[
k_KQ_0
/
16
][
j
],
KQ_c
[
j
]);
wmma
::
mma_sync
(
KQ_c
[
j
],
K_a
,
Q_b
[
k_KQ_0
/
16
][
j
],
KQ_c
[
j
]);
}
}
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
frag_n
)
{
nvcuda
::
wmma
::
store_matrix_sync
((
KQ_acc_t
*
)
KQ
+
j0
*
kqs_padded
+
i_KQ_0
+
frag_m
*
threadIdx
.
y
,
KQ_c
[
j0
/
frag_n
],
kqs_padded
,
nvcuda
::
wmma
::
mem_col_major
);
wmma
::
store_matrix_sync
((
KQ_acc_t
*
)
KQ
+
j0
*
kqs_padded
+
i_KQ_0
+
frag_m
*
threadIdx
.
y
,
KQ_c
[
j0
/
frag_n
],
kqs_padded
,
wmma
::
mem_col_major
);
}
}
...
...
@@ -202,27 +208,27 @@ static __global__ void flash_attn_ext_f16(
const
int
j
=
j0
+
threadIdx
.
y
;
if
(
std
::
is_same
<
KQ_acc_t
,
float
>::
value
)
{
float
KQ_f_tmp
[
FATTN_KQ_STRIDE
/
WARP_SIZE
];
float
KQ_f_tmp
[
FATTN_KQ_STRIDE
/
warp_size
];
#pragma unroll
for
(
int
k0
=
0
;
k0
<
FATTN_KQ_STRIDE
;
k0
+=
WARP_SIZE
)
{
for
(
int
k0
=
0
;
k0
<
FATTN_KQ_STRIDE
;
k0
+=
warp_size
)
{
const
int
k
=
k0
+
threadIdx
.
x
;
KQ_f_tmp
[
k0
/
WARP_SIZE
]
=
KQ_f
[
j
*
kqs_padded
+
k
];
KQ_f_tmp
[
k0
/
warp_size
]
=
KQ_f
[
j
*
kqs_padded
+
k
];
if
(
use_logit_softcap
)
{
KQ_f_tmp
[
k0
/
WARP_SIZE
]
=
logit_softcap
*
tanhf
(
KQ_f_tmp
[
k0
/
WARP_SIZE
]);
KQ_f_tmp
[
k0
/
warp_size
]
=
logit_softcap
*
tanhf
(
KQ_f_tmp
[
k0
/
warp_size
]);
}
}
float
KQ_max_new
=
KQ_max_f
[
j0
/
nwarps
];
#pragma unroll
for
(
int
k0
=
0
;
k0
<
FATTN_KQ_STRIDE
;
k0
+=
WARP_SIZE
)
{
for
(
int
k0
=
0
;
k0
<
FATTN_KQ_STRIDE
;
k0
+=
warp_size
)
{
const
int
k
=
k0
+
threadIdx
.
x
;
KQ_f_tmp
[
k0
/
WARP_SIZE
]
+=
mask
?
__half2float
(
slopeh
*
maskh
[
j
*
(
nb31
/
sizeof
(
half
))
+
k_VKQ_0
+
k
])
:
0.0
f
;
KQ_max_new
=
max
(
KQ_max_new
,
KQ_f_tmp
[
k0
/
WARP_SIZE
]);
KQ_f_tmp
[
k0
/
warp_size
]
+=
mask
?
__half2float
(
slopeh
*
maskh
[
j
*
(
nb31
/
sizeof
(
half
))
+
k_VKQ_0
+
k
])
:
0.0
f
;
KQ_max_new
=
max
(
KQ_max_new
,
KQ_f_tmp
[
k0
/
warp_size
]);
}
KQ_max_new
=
warp_reduce_max
(
KQ_max_new
);
KQ_max_new
=
warp_reduce_max
<
warp_size
>
(
KQ_max_new
);
const
float
diff
=
KQ_max_f
[
j0
/
nwarps
]
-
KQ_max_new
;
KQ_max_scale_f
[
j0
/
nwarps
]
=
expf
(
diff
);
...
...
@@ -233,48 +239,48 @@ static __global__ void flash_attn_ext_f16(
                 float KQ_rowsum_add = 0.0f;
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
-                    const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
-                    KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                    const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
+                    KQ_f_tmp[k0/warp_size] = expf(diff);
                     if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                        KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                        KQ_f_tmp[k0/warp_size] = 0.0f;
                     }
-                    KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
-                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                    KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
+                    KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
                 }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
             } else {
-                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+                half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
-                    KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+                    KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
                     if (use_logit_softcap) {
                         // There is no dedicated tangens hyperbolicus function for half2.
-                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
-                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+                        KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
+                        KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))/(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));
-                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                        KQ2_tmp[k0/warp_size] *= logit_softcap_2;
                     }
                 }
                 half2 KQ_max_new = KQ_max_h2[j0/nwarps];
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
-                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
-                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                    KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
                 }
-                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
                 const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
                 KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
                 const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
@@ -283,17 +289,17 @@ static __global__ void flash_attn_ext_f16(
                 half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
-                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                     const int k = k0 + threadIdx.x;
-                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
-                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                    const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
+                    KQ2_tmp[k0/warp_size] = h2exp(diff);
                     const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
-                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
-                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                    *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
+                    KQ_rowsum_add += KQ2_tmp[k0/warp_size];
+                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
                 }
-                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+                KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                 KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
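Note: the two branches above both implement the running ("online") softmax: whenever a new KQ tile raises the row maximum, the previously accumulated row sum is rescaled by exp(old_max - new_max) before the new tile is folded in. A minimal standalone sketch of that bookkeeping, with hypothetical names and none of the kernel's actual data layout:

#include <algorithm>
#include <cmath>

// Online-softmax accumulator for one attention row (sketch, not from the diff).
struct OnlineSoftmaxRow {
    float max_val = -INFINITY;
    float rowsum  = 0.0f;

    // Fold one tile of raw attention scores into the running statistics.
    void add_tile(const float * scores, int n) {
        float tile_max = max_val;
        for (int i = 0; i < n; ++i) {
            tile_max = std::max(tile_max, scores[i]);
        }
        // Scale previous rowsum to account for a potential increase in the maximum.
        const float scale = std::exp(max_val - tile_max);
        float tile_sum = 0.0f;
        for (int i = 0; i < n; ++i) {
            tile_sum += std::exp(scores[i] - tile_max);
        }
        rowsum  = scale*rowsum + tile_sum;
        max_val = tile_max;
    }
};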
@@ -308,7 +314,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
         for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-            nvcuda::wmma::load_matrix_sync(
+            wmma::load_matrix_sync(
                 KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                 KQ + j0*(kqar*kqs_padded) + k,
                 kqar*kqs_padded);

@@ -320,7 +326,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+            wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
         }
 #pragma unroll

@@ -328,10 +334,10 @@ static __global__ void flash_attn_ext_f16(
             const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
             frag_a_V v_a;
-            nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+            wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
             }
         }
     }
@@ -343,10 +349,10 @@ static __global__ void flash_attn_ext_f16(
     for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync(
+            wmma::store_matrix_sync(
                 KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                 VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                D_padded, nvcuda::wmma::mem_col_major);
+                D_padded, wmma::mem_col_major);
         }
     }

@@ -364,9 +370,9 @@ static __global__ void flash_attn_ext_f16(
     }
 #pragma unroll
-    for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+    for (int i0 = 0; i0 < D/2; i0 += warp_size) {
         const int i = i0 + threadIdx.x;
-        if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+        if (i0 + warp_size > D/2 && i >= D/2) {
             break;
         }

@@ -388,7 +394,7 @@ static __global__ void flash_attn_ext_f16(
         if (ic0 + j_VKQ >= ne01) {
             return;
         }
-        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
+        const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
         float KQ_rowsum_j;
         if (std::is_same<KQ_acc_t, float>::value) {
@@ -398,19 +404,19 @@ static __global__ void flash_attn_ext_f16(
         }
 #pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
+            if (i0 + warp_size > D && i >= D) {
                 break;
             }
             float dst_val = VKQ[j_VKQ*D_padded + i];
-            if (parallel_blocks == 1) {
+            if (gridDim.y == 1) {
                 dst_val /= KQ_rowsum_j;
             }
-            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
+            dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
         }
-        if (parallel_blocks == 1 || threadIdx.x != 0) {
+        if (gridDim.y == 1 || threadIdx.x != 0) {
             continue;
         }

@@ -421,11 +427,21 @@ static __global__ void flash_attn_ext_f16(
             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
         }
         dst_meta_val.y = KQ_rowsum_j;
-        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
+        dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z)*gridDim.y + blockIdx.y] = dst_meta_val;
     }
 #else
-    NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
+    GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2);
+    GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
+    GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(nb31);
+    GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
+    GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
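When the kernel runs with more than one block per output row (gridDim.y > 1), each block writes an unnormalized partial result plus its local (row max, row sum) pair to dst_meta, and a separate fixup pass combines them. A hedged sketch of how two such partials can be merged, written as plain host code for illustration (the real combination happens in a dedicated CUDA kernel):

#include <cmath>

// Combine two partial flash-attention results for the same output row.
// Each partial carries its local softmax max (m) and row sum (s); o1/o2 are
// length-D outputs already normalized by their own local sums. Hypothetical helper.
void combine_partials(const float * o1, float m1, float s1,
                      const float * o2, float m2, float s2,
                      float * out, int D) {
    const float m  = std::fmax(m1, m2);
    const float w1 = std::exp(m1 - m)*s1;   // weight of partial 1 under the global max
    const float w2 = std::exp(m2 - m)*s2;   // weight of partial 2 under the global max
    const float denom = w1 + w2;
    for (int i = 0; i < D; ++i) {
        out[i] = (w1*o1[i] + w2*o2[i]) / denom;
    }
}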
constexpr int get_max_power_of_2(int x) {
@@ -455,59 +471,26 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
 template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
     constexpr int nwarps = 4;
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
-    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
-    if (4*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
-        return;
-    }
-    if (2*blocks_num_pb1 < 2*nsm) {
-        constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel;
-        if (logit_softcap == 0.0f) {
-            constexpr bool use_logit_softcap = false;
-            fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        } else {
-            constexpr bool use_logit_softcap = true;
-            fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
-        }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
-        return;
-    }
-    constexpr int parallel_blocks = 1;
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
         fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     } else {
         constexpr bool use_logit_softcap = true;
         fattn_kernel = flash_attn_ext_f16<
-            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
 }
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
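The case function above now resolves only one compile-time decision at run time: whether logit softcapping is enabled. A reduced sketch of selecting between two template instantiations of a kernel on such a flag (hypothetical kernel name, not the file's actual kernel):

// Choosing between two instantiations so the softcap branch is compiled out (sketch).
template <bool use_softcap>
__global__ void my_kernel(const float * in, float * out, int n, float softcap) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    float v = in[i];
    if (use_softcap) {               // resolved at compile time per instantiation
        v = softcap*tanhf(v/softcap);
    }
    out[i] = v;
}

using kernel_t = void (*)(const float *, float *, int, float);

kernel_t pick_kernel(float softcap) {
    return softcap == 0.0f ? my_kernel<false> : my_kernel<true>;
}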
@@ -515,6 +498,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
     const ggml_tensor * Q = dst->src[0];
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
     if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {

@@ -571,7 +555,8 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         return;
     }
-    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+    if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
             case 64:

@@ -592,6 +577,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
View file @ 943464cc

@@ -250,10 +250,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
-    // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
+    if (GGML_CUDA_CC_IS_AMD(cc)) {
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+        if (fp16_mma_available(cc)) {
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+            return;
+        }
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+        // On AMD the tile kernels perform poorly, use the vec kernel instead:
         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
@@ -273,13 +281,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     if (!fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT) {
-            if (Q->ne[1] <= 8) {
+            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             } else {
                 ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
             }
         } else {
-            if (Q->ne[1] <= 8) {
+            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
                 ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
             } else {
                 ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);

@@ -288,21 +296,21 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
-    const int gqa_ratio = Q->ne[2] / K->ne[2];
-    const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
-        K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
+    const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
+    const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
+    const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
+    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
-        } else if (Q->ne[0] <= 128) {
+        } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
             return;
         }
-        return;
     }
-    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (cc == GGML_CUDA_CC_VOLTA) {
+    if (fp16_mma_available(cc) && !new_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
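The rewritten dispatcher above selects between the vector, tile, WMMA and MMA kernels from a handful of device and tensor properties. A compressed sketch of the decision order with hypothetical boolean inputs, mirroring the structure of the code rather than replacing it:

enum class fattn_kernel_kind { vec_f16, vec_f32, tile_f16, tile_f32, wmma_f16, mma_f16 };

// Decision sketch: without fp16 MMA hardware, fall back to tile kernels; batch
// size 1 prefers the vector kernel unless the MMA path is expected to be faster;
// fp16-MMA-but-no-new-MMA devices (e.g. Volta-class) use the WMMA kernel.
fattn_kernel_kind choose_kernel(bool fp16_mma, bool new_mma, bool prec_default,
                                bool batch_size_1, bool can_use_vec, bool mma_faster_for_bs1) {
    if (!fp16_mma) {
        return prec_default ? fattn_kernel_kind::tile_f16 : fattn_kernel_kind::tile_f32;
    }
    if (batch_size_1 && can_use_vec && !mma_faster_for_bs1) {
        return prec_default ? fattn_kernel_kind::vec_f16 : fattn_kernel_kind::vec_f32;
    }
    if (!new_mma) {
        return fattn_kernel_kind::wmma_f16;
    }
    return fattn_kernel_kind::mma_f16;
}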
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
View file @ 943464cc

@@ -31,12 +31,14 @@
 #include "ggml-cuda/rope.cuh"
 #include "ggml-cuda/scale.cuh"
 #include "ggml-cuda/softmax.cuh"
+#include "ggml-cuda/ssm-conv.cuh"
+#include "ggml-cuda/ssm-scan.cuh"
 #include "ggml-cuda/sum.cuh"
 #include "ggml-cuda/sumrows.cuh"
 #include "ggml-cuda/tsembd.cuh"
 #include "ggml-cuda/unary.cuh"
 #include "ggml-cuda/upscale.cuh"
-#include "ggml-cuda/wkv6.cuh"
+#include "ggml-cuda/wkv.cuh"
 #include "ggml-cuda/gla.cuh"
 #include "ggml.h"
@@ -262,9 +264,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
                       id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
                       device_vmm ? "yes" : "no", prop.warpSize);
 #elif defined(GGML_USE_MUSA)
         // TODO: refine the .cc to reflect MUSA's actual CC capabilities
+        // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
+        info.devices[id].warp_size = 32;
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
-        info.devices[id].cc = 100*prop.major + 10*prop.minor;
+        info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
+        info.devices[id].cc += prop.minor * 0x10;
         GGML_LOG_INFO("  Device %d: %s, compute capability %d.%d, VMM: %s\n",
                       id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
 #else
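MUSA compute capability is now packed into the same offset-plus-hex scheme the file uses for other vendors: vendor offset plus major*0x100 plus minor*0x10. A small sketch of packing and unpacking under that assumption (the 0x0800 offset constant below is a placeholder, not the real value of GGML_CUDA_CC_OFFSET_MTHREADS):

// Assumed encoding: vendor offset, then 0x100 per major and 0x10 per minor version.
constexpr int pack_cc(int offset, int major, int minor) {
    return offset + major*0x100 + minor*0x10;
}
constexpr int cc_major(int cc, int offset) { return ((cc - offset) >> 8) & 0xf; }
constexpr int cc_minor(int cc, int offset) { return ((cc - offset) >> 4) & 0xf; }

static_assert(cc_major(pack_cc(0x0800, 2, 1), 0x0800) == 2, "major roundtrip");
static_assert(cc_minor(pack_cc(0x0800, 2, 1), 0x0800) == 1, "minor roundtrip");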
@@ -541,12 +545,12 @@ static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
     return ctx->dev_ptr;
 }
-static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context;
     if (tensor->view_src != NULL) {
         assert(tensor->view_src->buffer->buft == buffer->buft);
-        return;
+        return GGML_STATUS_SUCCESS;
     }
     if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {

@@ -559,6 +563,7 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
             CUDA_CHECK(cudaMemset((char *) tensor->data + original_size, 0, padded_size - original_size));
         }
     }
+    return GGML_STATUS_SUCCESS;
 }
 static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
@@ -794,7 +799,7 @@ static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buff
     GGML_UNUSED(buffer);
 }
-static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
     GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
     ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *) buffer->context;

@@ -840,6 +845,7 @@ static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buf
         }
     }
     tensor->extra = extra;
+    return GGML_STATUS_SUCCESS;
 }
 static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
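Both buffer init_tensor callbacks now return enum ggml_status instead of void, so allocation-time failures can propagate to the caller. A minimal sketch of the callback shape under that convention (the error branch and status value are illustrative, assuming GGML_STATUS_ALLOC_FAILED exists as in upstream ggml):

// Sketch of an init_tensor callback that reports failure instead of asserting.
static enum ggml_status example_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
    GGML_UNUSED(buffer);
    if (tensor->view_src != NULL) {
        // views share their source's storage; nothing to initialize
        return GGML_STATUS_SUCCESS;
    }
    if (tensor->data == NULL) {
        return GGML_STATUS_ALLOC_FAILED; // illustrative error path
    }
    return GGML_STATUS_SUCCESS;
}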
@@ -1187,11 +1193,39 @@ static void ggml_cuda_op_mul_mat_cublas(
     // ldc == nrows of the matrix that cuBLAS writes into
     int64_t ldc = id == ctx.device ? ne0 : row_diff;
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
+    const int cc = ggml_cuda_info().devices[id].cc;
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
-    if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
+    if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
+        ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
+        if (src1->type != GGML_TYPE_BF16) {
+            const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
+            GGML_ASSERT(to_bf16_cuda != nullptr);
+            size_t ne = src1_ncols*ne10;
+            src1_as_bf16.alloc(ne);
+            to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream);
+        }
+        const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
+        const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *) src0_dd_i;
+        ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
+        const float alpha_f32 = 1.0f;
+        const float beta_f32  = 0.0f;
+        CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
+        CUBLAS_CHECK(
+            cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
+                    row_diff, src1_ncols, ne10,
+                    &alpha_f32, src0_ptr,       CUDA_R_16BF, ne00,
+                                src1_ptr,       CUDA_R_16BF, ne10,
+                    &beta_f32,  dst_bf16.get(), CUDA_R_16BF, ldc,
+                    CUBLAS_COMPUTE_32F,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
+        to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+    } else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {

@@ -1215,7 +1249,7 @@ static void ggml_cuda_op_mul_mat_cublas(
         CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
-        if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
+        if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
             const float alpha = 1.0f;
             const float beta  = 0.0f;
             CUBLAS_CHECK(
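The new branch multiplies BF16 matrices with cublasGemmEx, keeping the storage type CUDA_R_16BF while accumulating in FP32. A stripped-down sketch of that call pattern, self-contained and with square matrices instead of the ggml pool allocations above:

#include <cublas_v2.h>
#include <cuda_bf16.h>

// C = A^T * B with BF16 storage and FP32 compute, mirroring the cublasGemmEx
// call in the BF16 branch above (sketch; error handling limited to the returned status).
cublasStatus_t gemm_bf16(cublasHandle_t handle, int n,
                         const nv_bfloat16 * A, const nv_bfloat16 * B, nv_bfloat16 * C) {
    const float alpha = 1.0f;
    const float beta  = 0.0f;
    return cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, n, n, n,
                        &alpha, A, CUDA_R_16BF, n,
                                B, CUDA_R_16BF, n,
                        &beta,  C, CUDA_R_16BF, n,
                        CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}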
@@ -1758,7 +1792,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         beta  = &beta_f32;
     }
-    if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) {
+    int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
+    if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
         alpha = &alpha_f32;
         beta  = &beta_f32;

@@ -1835,7 +1871,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     }
 #endif
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
+    if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
         to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
     }
@@ -2148,6 +2184,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             break;
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(dst)) {
+                case GGML_UNARY_OP_ABS:
+                    ggml_cuda_op_abs(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SGN:
+                    ggml_cuda_op_sgn(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_NEG:
                     ggml_cuda_op_neg(ctx, dst);
                     break;

@@ -2191,6 +2233,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GROUP_NORM:
             ggml_cuda_op_group_norm(ctx, dst);
             break;
+        case GGML_OP_L2_NORM:
+            ggml_cuda_op_l2_norm(ctx, dst);
+            break;
         case GGML_OP_CONCAT:
             ggml_cuda_op_concat(ctx, dst);
             break;

@@ -2248,6 +2293,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CLAMP:
             ggml_cuda_op_clamp(ctx, dst);
             break;
+        case GGML_OP_LOG:
+            ggml_cuda_op_log(ctx, dst);
+            break;
         case GGML_OP_NONE:
         case GGML_OP_RESHAPE:
         case GGML_OP_VIEW:

@@ -2284,6 +2332,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SUM_ROWS:
             ggml_cuda_op_sum_rows(ctx, dst);
             break;
+        case GGML_OP_SSM_CONV:
+            ggml_cuda_op_ssm_conv(ctx, dst);
+            break;
+        case GGML_OP_SSM_SCAN:
+            ggml_cuda_op_ssm_scan(ctx, dst);
+            break;
         case GGML_OP_ARGSORT:
             ggml_cuda_op_argsort(ctx, dst);
             break;

@@ -2301,6 +2355,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_GATED_LINEAR_ATTN:
             ggml_cuda_op_gated_linear_attn(ctx, dst);
             break;
+        case GGML_OP_RWKV_WKV7:
+            ggml_cuda_op_rwkv_wkv7(ctx, dst);
+            break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
             ggml_cuda_cross_entropy_loss_back(ctx, dst);
             break;
@@ -2568,7 +2625,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
     for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
         if (count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
             char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
-            cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
+            *(void **) cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void **) updated_kernel_arg_ptr;
             CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
         }
     }

@@ -2607,13 +2664,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
 static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 #if CUDART_VERSION >= 12000
     cudaGraphExecUpdateResultInfo result_info;
+#ifdef __HIP_PLATFORM_AMD__
+    hipGraphNode_t errorNode;
+    hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
+#else
     cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
+#endif
 #else
     cudaGraphNode_t errorNode;
     cudaGraphExecUpdateResult result_info;
     cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
 #endif // CUDART_VERSION >= 12000
     if (stat == cudaErrorGraphExecUpdateFailure) {
 #ifndef NDEBUG
         GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
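update_cuda_graph_executable tries an in-place cudaGraphExecUpdate and only falls back to re-instantiation when the update is rejected. A hedged sketch of that try-update-else-reinstantiate pattern (CUDA 12 result-info API assumed; the 5-argument cudaGraphInstantiate form matches what this file uses elsewhere):

#include <cuda_runtime.h>

// Try to update an existing executable graph in place; rebuild it only if the
// topology changed too much for an in-place update.
static void refresh_graph_exec(cudaGraph_t graph, cudaGraphExec_t * exec) {
    cudaGraphExecUpdateResultInfo info;
    cudaError_t stat = cudaGraphExecUpdate(*exec, graph, &info);
    if (stat != cudaSuccess) {
        (void) cudaGetLastError();                       // clear the sticky error
        cudaGraphExecDestroy(*exec);
        cudaGraphInstantiate(exec, graph, NULL, NULL, 0); // full re-instantiation
    }
}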
@@ -2968,6 +3027,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_ABS:
+                case GGML_UNARY_OP_SGN:
                 case GGML_UNARY_OP_NEG:
                 case GGML_UNARY_OP_STEP:
                 case GGML_UNARY_OP_GELU:

@@ -3071,6 +3132,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
+                    return true;
+                }
                 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
                     return true;
                 }

@@ -3150,10 +3214,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return false;
             }
             break;
         case GGML_OP_SILU_BACK:
-            return ggml_is_contiguous(op->src[0]);
+            return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
             break;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             return true;
         case GGML_OP_RMS_NORM_BACK:
             return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;

@@ -3174,6 +3239,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_SIN:
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
+        case GGML_OP_LOG:
+        case GGML_OP_SSM_SCAN:
+        case GGML_OP_SSM_CONV:
             return true;
         case GGML_OP_CONT:
             return op->src[0]->type != GGML_TYPE_BF16;

@@ -3201,6 +3269,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GROUP_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_UPSCALE:
            return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_PAD:
         case GGML_OP_UNPAD:
         case GGML_OP_ARANGE:

@@ -3208,11 +3277,22 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_GATED_LINEAR_ATTN:
+        case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif // FLASH_ATTN_AVAILABLE
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
+            if (op->src[0]->ne[3] != 1) {
+                return false;
+            }
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
ml/backend/ggml/ggml/src/ggml-cuda/mma.cuh
View file @ 943464cc

@@ -26,6 +26,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
         : "=r"(ret) : "r"(x));
 #else
+    GGML_UNUSED(x);
     NO_DEVICE_CODE;
 #endif // defined(NEW_MMA_AVAILABLE)
     return ret;

@@ -178,6 +179,7 @@ namespace ggml_cuda_mma {
             : "l"(xs));
 #else
         load_generic(xs0, stride);
+        GGML_UNUSED(t);
 #endif // NEW_MMA_AVAILABLE
     }
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu
View file @ 943464cc

@@ -27,8 +27,8 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
+    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
     const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
     switch (src0->type) {

@@ -145,7 +145,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
     return true;
 #endif //GGML_CUDA_FORCE_MMQ
-    if (cc < GGML_CUDA_CC_OFFSET_AMD) {
+    if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
         return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
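use_stream_k now gates on GGML_CUDA_CC_IS_NVIDIA(cc) rather than comparing cc against the AMD offset, which reads directly as a vendor check. A small sketch of the same predicate factored into a helper (hypothetical function, assuming the macros used above are in scope):

// Stream-K is only enabled for recent NVIDIA architectures and full-batch columns.
static bool should_use_stream_k(int cc, int64_t src1_ncols, int64_t ne11) {
    const bool nvidia_volta_or_newer =
        GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
    return nvidia_volta_or_newer && src1_ncols == ne11;
}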
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cuh
View file @ 943464cc

@@ -90,7 +90,7 @@ struct tile_x_sizes {
 static int get_mmq_x_max_host(const int cc) {
     return new_mma_available(cc) ? 128 :
-        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
+        GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
 #ifdef GGML_CUDA_FORCE_MMQ
             128                     : 64;
 #else

@@ -109,9 +109,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
 #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #ifdef GGML_CUDA_FORCE_MMQ
     return MMQ_DP4A_MAX_BATCH_SIZE;
-#else // GGML_CUDA_FORCE_MMQ
-    return 128;
+#else // GGML_CUDA_FORCE_MMQ
+    return MMQ_DP4A_MAX_BATCH_SIZE;
 #endif // GGML_CUDA_FORCE_MMQ
 #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA

@@ -123,8 +123,8 @@ static constexpr __device__ int get_mmq_x_max_device() {
 }
 static int get_mmq_y_host(const int cc) {
-    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : (ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
+    return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) : ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 }
 static constexpr __device__ int get_mmq_y_device() {
@@ -945,7 +945,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
     }
 }
 #else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }

@@ -1024,7 +1024,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
     }
 #pragma unroll
-    for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+    for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
         const int k0 = k00 + k01;
 #pragma unroll

@@ -1035,19 +1035,34 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
             for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
-                if (k01 < WARP_SIZE/2) {
-                    constexpr int ns = 2;
-                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
-                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
-                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
-                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
-                } else {
-                    constexpr int ns = 1;
-                    sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
-                        &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
-                        &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
-                        &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
-                }
+                constexpr int ns = 2;
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
             }
         }
     }
+    // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
+    // As a workaround 2 separate loops are used instead.
+#pragma unroll
+    for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
+        const int k0 = k00 + k01;
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+                constexpr int ns = 1;
+                sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
+                    &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
+                    &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
+                    &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
+            }
+        }
+    }
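The comment above explains the motivation for the split: choosing a constexpr template argument with a run-time conditional inside the innermost loop can defeat #pragma unroll on some compilers. A minimal illustration of the same workaround under that assumption (hypothetical accumulate<ns>() helper, not the file's real dot-product routine):

// Instead of one loop with "ns = (k < HALF) ? 2 : 1" inside, run two loops with ns
// fixed per loop so the compiler can fully unroll each of them.
template <int ns>
__device__ float accumulate(float v) {
    return ns*v;            // stand-in for the real per-element work
}

template <int N, int HALF>
__device__ float sum_split(const float * x) {
    float acc = 0.0f;
#pragma unroll
    for (int k = 0; k < HALF; ++k) {
        acc += accumulate<2>(x[k]);     // ns == 2 for the first half
    }
#pragma unroll
    for (int k = HALF; k < N; ++k) {
        acc += accumulate<1>(x[k]);     // ns == 1 for the second half
    }
    return acc;
}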
@@ -1176,7 +1191,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
     }
 }
 #else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }

@@ -1253,7 +1268,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const float d = bxi->d;
 #pragma unroll
-    for (int l = 0; l < sizeof(int); ++l) {
+    for (int l = 0; l < int(sizeof(int)); ++l) {
         x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
     }
 #else

@@ -1376,7 +1391,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
 #pragma unroll
-    for (int l = 0; l < sizeof(int); ++l) {
+    for (int l = 0; l < int(sizeof(int)); ++l) {
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
     }
 }

@@ -1517,7 +1532,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
 #pragma unroll
-    for (int l = 0; l < sizeof(int); ++l) {
+    for (int l = 0; l < int(sizeof(int)); ++l) {
         x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
     }
 }

@@ -1810,7 +1825,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
     }
 }
 #else
-    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
+    GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
     NO_DEVICE_CODE;
 #endif // NEW_MMA_AVAILABLE
 }
@@ -2570,6 +2585,8 @@ static __device__ void mul_mat_q_process_tile(
     } else {
         write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
     }
+    GGML_UNUSED(ne00); GGML_UNUSED(ne10);
 }

@@ -2695,7 +2712,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
         const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
         // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
-        if (it != blockIdx.x || jt != blockIdx.y) {
+        if ((unsigned) it != blockIdx.x || (unsigned) jt != blockIdx.y) {
             continue;
         }

@@ -2772,14 +2789,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
     if (!shmem_limit_raised[id]) {
         CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         shmem_limit_raised[id] = true;
     }
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
     const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
     const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;

@@ -2825,14 +2842,13 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 template <ggml_type type>
 void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id    = ggml_cuda_get_device();
     const int nsm   = ggml_cuda_info().devices[id].nsm;
     const int cc    = ggml_cuda_info().devices[id].cc;
     const int smpbo = ggml_cuda_info().devices[id].smpbo;
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
-    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
+    const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
     int mmq_x_best  = 0;
     int nparts_best = INT_MAX;
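launch_mul_mat_q raises the dynamic shared-memory limit once per device with cudaFuncSetAttribute before launching kernels that need more than the default budget; the change above only excludes MUSA from that path. A short self-contained sketch of the opt-in pattern (hypothetical kernel, single-device, no error checking):

#include <cuda_runtime.h>

__global__ void big_smem_kernel(float * out) {
    extern __shared__ float buf[];          // dynamic shared memory
    buf[threadIdx.x] = (float) threadIdx.x;
    __syncthreads();
    out[threadIdx.x] = buf[threadIdx.x];
}

static void launch_with_large_smem(float * out, size_t smem_bytes, cudaStream_t stream) {
    // Opt in once; without this, launches requesting more dynamic shared memory
    // than the default limit fail on devices that support larger carve-outs.
    static bool raised = false;
    if (!raised) {
        cudaFuncSetAttribute(big_smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, (int) smem_bytes);
        raised = true;
    }
    big_smem_kernel<<<1, 256, smem_bytes, stream>>>(out);
}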
ml/backend/ggml/ggml/src/ggml-cuda/mmv.cu
View file @ 943464cc

@@ -29,7 +29,7 @@ static __global__ void mul_mat_vec(
         __syncthreads();
     }
-    float sumf;
+    float sumf = 0.0f;
     if constexpr (std::is_same<T, half>::value) {
         const half2 * x2 = (const half2 *) x;
ml/backend/ggml/ggml/src/ggml-cuda/mmvq.cu
View file @ 943464cc

@@ -47,11 +47,89 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         1;
 }
+enum mmvq_parameter_table_id {
+    MMVQ_PARAMETERS_GENERIC = 0,
+    MMVQ_PARAMETERS_GCN,
+    MMVQ_PARAMETERS_RDNA2
+};
+static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
+#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
+    return MMVQ_PARAMETERS_RDNA2;
+#elif defined(GCN) || defined(CDNA)
+    return MMVQ_PARAMETERS_GCN;
+#else
+    return MMVQ_PARAMETERS_GENERIC;
+#endif
+}
+static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
+    if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
+        return MMVQ_PARAMETERS_RDNA2;
+    }
+    if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
+        return MMVQ_PARAMETERS_GCN;
+    }
+    return MMVQ_PARAMETERS_GENERIC;
+}
+static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC) {
+        switch (ncols_y) {
+            case 1: case 2: case 3: case 4:
+                return 4;
+            case 5: case 6: case 7: case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    } else if (table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1: case 2: case 3: case 4:
+                return 2;
+            case 5: case 6: case 7: case 8:
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
+    if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
+        switch (ncols_y) {
+            case 1:
+                return 1;
+            case 2: case 3: case 4: case 5: case 6: case 7: case 8:
+                return 2;
+            default:
+                return 1;
+        }
+    }
+    return 1;
+}
 template <ggml_type type, int ncols_y>
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {

@@ -59,25 +137,21 @@ static __global__ void mul_mat_vec_q(
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
     constexpr int vdr = get_vdr_mmvq(type);
+    constexpr mmvq_parameter_table_id table_id = get_device_table_id();
+    constexpr int nwarps = calc_nwarps(ncols_y, table_id);
+    constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
-    constexpr int nwarps              = 1;
-    constexpr int rows_per_cuda_block = 1;
-#else
-    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
-    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
-    const     int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    const     int tid = warp_size*threadIdx.y + threadIdx.x;
     const     int row0 = rows_per_cuda_block*blockIdx.x;
     const     int blocks_per_row_x = ncols_x / qk;
     const     int blocks_per_col_y = nrows_y / QK8_1;
-    constexpr int blocks_per_iter = vdr*nwarps*WARP_SIZE / qi;
+    constexpr int blocks_per_iter = vdr*nwarps*warp_size / qi;
-    // partial sum for each thread
-    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
+    // partial sum for each thread
+    float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
     const block_q8_1 * y = (const block_q8_1 *) vy;

@@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q(
         }
     }
-    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
+    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
     if (threadIdx.y > 0) {
 #pragma unroll
         for (int j = 0; j < ncols_y; ++j) {

@@ -120,13 +194,22 @@ static __global__ void mul_mat_vec_q(
             for (int l = 0; l < nwarps-1; ++l) {
                 tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
             }
-            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
+            tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
         }
-        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
+        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned) nrows_dst)) {
             dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
         }
     }
+    GGML_UNUSED(nrows_x);
 }
+static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
+    const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
+    const dim3 block_nums(nblocks, 1, 1);
+    const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+    return {block_nums, block_dims};
+}
 template <ggml_type type>

@@ -137,65 +220,67 @@ static void mul_mat_vec_q_cuda(
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
-    int id = ggml_cuda_get_device();
-    int64_t nwarps = 1;
-    int64_t rows_per_cuda_block = 1;
-    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
-        switch (ncols_y) {
-            case 1:
-                nwarps = 4;
-                rows_per_cuda_block = 1;
-                break;
-            case 2:
-            case 3:
-            case 4:
-                nwarps = 4;
-                rows_per_cuda_block = 2;
-                break;
-            case 5:
-            case 6:
-            case 7:
-            case 8:
-                nwarps = 2;
-                rows_per_cuda_block = 2;
-                break;
-            default:
-                GGML_ABORT("fatal error");
-                break;
-        }
-    }
-    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
-    const dim3 block_nums(nblocks, 1, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
     switch (ncols_y) {
         case 1:
-            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 1;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
             break;
+        }
         case 2:
-            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 2;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
+        }
        case 3:
-            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 3;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
+        }
        case 4:
-            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 4;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
+        }
        case 5:
-            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 5;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
+        }
        case 6:
-            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 6;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
+        }
        case 7:
-            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 7;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
+        }
        case 8:
-            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+        {
+            constexpr int c_ncols_y = 8;
+            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
+            mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
+        }
        default:
            GGML_ABORT("fatal error");
            break;
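Each ncols_y case above now derives its grid and block dimensions from calc_launch_params, which folds in the per-device warp size and the parameter table. A hedged sketch of that table-driven launch configuration with simplified inputs (illustrative helper, not the file's exact function); a call site would then launch with kernel<<<dims.first, dims.second, 0, stream>>>(...):

#include <cuda_runtime.h>
#include <utility>

// Sketch: compute (grid, block) from the row blocking and warp count chosen
// for the current device, mirroring calc_launch_params above.
static std::pair<dim3, dim3> launch_config(int nrows_x, int warp_size,
                                           int rows_per_block, int nwarps) {
    const int nblocks = (nrows_x + rows_per_block - 1) / rows_per_block;
    return { dim3(nblocks, 1, 1), dim3(warp_size, nwarps, 1) };
}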
ml/backend/ggml/ggml/src/ggml-cuda/norm.cu
View file @ 943464cc

@@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
     }
 }
// template <int block_size>
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
//     const int row = blockIdx.x*blockDim.y + threadIdx.y;
//     const int tid = threadIdx.x;
//     float tmp = 0.0f; // partial sum for thread in warp
//     for (int col = tid; col < ncols; col += block_size) {
//         const float xi = x[row*ncols + col];
//         tmp += xi * xi;
//     }
//     // sum up partial sums
//     tmp = warp_reduce_sum(tmp);
//     if (block_size > WARP_SIZE) {
//         __shared__ float s_sum[32];
//         int warp_id = threadIdx.x / WARP_SIZE;
//         int lane_id = threadIdx.x % WARP_SIZE;
//         if (lane_id == 0) {
//             s_sum[warp_id] = tmp;
//         }
//         __syncthreads();
//         tmp = s_sum[lane_id];
//         tmp = warp_reduce_sum(tmp);
//     }
//     // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
//     const float scale = rsqrtf(fmaxf(tmp, eps * eps));
//     for (int col = tid; col < ncols; col += block_size) {
//         dst[row*ncols + col] = scale * x[row*ncols + col];
//     }
// }
+template <int block_size>
+static __global__ void l2_norm_f32(
+        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
+        const int64_t stride_sample, const float eps) {
+    const int nrows     = gridDim.x;
+    const int nchannels = gridDim.y;
+    const int row       = blockIdx.x;
+    const int channel   = blockIdx.y;
+    const int sample    = blockIdx.z;
+    const int tid       = threadIdx.x;
+    x   += sample*stride_sample + channel*stride_channel + row*stride_row;
+    dst += ((sample*nchannels + channel)*nrows + row)*ncols;
+    float tmp = 0.0f; // partial sum for thread in warp
+    for (int col = tid; col < ncols; col += block_size) {
+        const float xi = x[col];
+        tmp += xi * xi;
+    }
+    // sum up partial sums
+    tmp = warp_reduce_sum(tmp);
+    if constexpr (block_size > WARP_SIZE) {
+        static_assert(block_size == 1024, "unexpected block_size");
+        __shared__ float s_sum[32];
+        const int warp_id = threadIdx.x / WARP_SIZE;
+        const int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+    // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
+    const float scale = rsqrtf(fmaxf(tmp, eps * eps));
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[col] = scale * x[col];
+    }
+}
 static void norm_f32_cuda(
     const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
     const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {

@@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
     }
 }
+static void l2_norm_f32_cuda(
+        const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
+        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
+    const dim3 blocks_num(nrows, nchannels, nsamples);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
+    }
+}
 void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *) src0->data;

@@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
     rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
 }
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *) src0->data;
+    float * dst_d = (float *) dst->data;
+    cudaStream_t stream = ctx.stream();
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_TENSOR_UNARY_OP_LOCALS;
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+    GGML_ASSERT(eps >= 0.0f);
+    const size_t ts0 = ggml_type_size(src0->type);
+    GGML_ASSERT(nb00 == ts0);
+    const int64_t s01 = nb01 / ts0;
+    const int64_t s02 = nb02 / ts0;
+    const int64_t s03 = nb03 / ts0;
+    l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
+}
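As the kernel's comment notes, it follows torch.nn.functional.normalize: each row is scaled so that y = x / max(||x||_2, eps), which is what rsqrtf(fmaxf(sum_sq, eps*eps)) computes. A tiny host-side reference that can be used to check the kernel (plain C++, assuming contiguous row-major rows; not part of the diff):

#include <algorithm>
#include <cmath>

// Reference for l2_norm_f32: y = x / max(||x||_2, eps), applied per row.
void l2_norm_ref(const float * x, float * y, int nrows, int ncols, float eps) {
    for (int r = 0; r < nrows; ++r) {
        double sum_sq = 0.0;
        for (int c = 0; c < ncols; ++c) {
            sum_sq += (double) x[r*ncols + c] * x[r*ncols + c];
        }
        const float scale = 1.0f / std::max((float) std::sqrt(sum_sq), eps);
        for (int c = 0; c < ncols; ++c) {
            y[r*ncols + c] = scale * x[r*ncols + c];
        }
    }
}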
ml/backend/ggml/ggml/src/ggml-cuda/norm.cuh
View file @ 943464cc

@@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
ml/backend/ggml/ggml/src/ggml-cuda/pad.cu
View file @ 943464cc

@@ -14,7 +14,7 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
         nidx + blockIdx.y * ne0 + blockIdx.z * ne0 * gridDim.y;
-    if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
+    if (nidx < ne00 && blockIdx.y < (unsigned) ne01 && blockIdx.z < (unsigned) (ne02*ne03)) {
         int offset_src =
             nidx + blockIdx.y * ne00 +
ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cu
0 → 100644
View file @
943464cc
#include "ssm-conv.cuh"
template
<
size_t
split_d_inner
,
size_t
d_conv
>
static
__global__
void
ssm_conv_f32
(
const
float
*
__restrict__
src0
,
const
float
*
__restrict__
src1
,
const
int
src0_nb0
,
const
int
src0_nb1
,
const
int
src0_nb2
,
const
int
src1_nb1
,
float
*
__restrict__
dst
,
const
int
dst_nb0
,
const
int
dst_nb1
,
const
int
dst_nb2
,
const
int64_t
n_t
)
{
GGML_UNUSED
(
src0_nb0
);
const
int
tid
=
threadIdx
.
x
;
const
int
bidx
=
blockIdx
.
x
;
const
int
bidy
=
blockIdx
.
y
;
const
float
*
x_block
=
(
const
float
*
)
((
const
char
*
)
src0
+
bidx
*
src0_nb2
+
bidy
*
split_d_inner
*
src0_nb1
);
const
float
*
w_block
=
(
const
float
*
)
((
const
char
*
)
src1
+
bidy
*
split_d_inner
*
src1_nb1
);
float
*
y_block
=
(
float
*
)
((
char
*
)
dst
+
bidx
*
dst_nb2
+
bidy
*
split_d_inner
*
dst_nb0
);
const
int
stride_x
=
src0_nb1
/
sizeof
(
float
);
const
int
stride_w
=
src1_nb1
/
sizeof
(
float
);
const
int
stride_y
=
dst_nb1
/
sizeof
(
float
);
float
x
[
d_conv
]
=
{
0.0
f
};
float
w
[
d_conv
]
=
{
0.0
f
};
#pragma unroll
for
(
size_t
j
=
0
;
j
<
d_conv
;
j
++
)
{
w
[
j
]
=
w_block
[
tid
*
stride_w
+
j
];
}
for
(
int64_t
i
=
0
;
i
<
n_t
;
i
++
)
{
float
sumf
=
0.0
f
;
if
(
i
==
0
)
{
for
(
size_t
j
=
0
;
j
<
d_conv
;
j
++
)
{
x
[
j
]
=
x_block
[
tid
*
stride_x
+
j
];
}
}
else
{
x
[(
i
-
1
)
%
d_conv
]
=
x_block
[
tid
*
stride_x
+
i
+
d_conv
-
1
];
}
#pragma unroll
for
(
size_t
j
=
0
;
j
<
d_conv
;
j
++
)
{
sumf
+=
x
[(
i
+
j
)
%
d_conv
]
*
w
[
j
];
}
y_block
[
i
*
stride_y
+
tid
]
=
sumf
;
}
}
template
<
size_t
split_d_inner
,
size_t
d_conv
,
int64_t
split_n_t
>
static
__global__
void
ssm_conv_long_token_f32
(
const
float
*
__restrict__
src0
,
const
float
*
__restrict__
src1
,
const
int
src0_nb0
,
const
int
src0_nb1
,
const
int
src0_nb2
,
const
int
src1_nb1
,
float
*
__restrict__
dst
,
const
int
dst_nb0
,
const
int
dst_nb1
,
const
int
dst_nb2
,
const
int64_t
n_t
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
bidx
=
blockIdx
.
x
;
const
int
bidy
=
blockIdx
.
y
;
const
int
bidz
=
blockIdx
.
z
;
const
float
*
x_block
=
(
const
float
*
)
((
const
char
*
)
src0
+
bidx
*
src0_nb2
+
bidy
*
split_d_inner
*
src0_nb1
+
bidz
*
split_n_t
*
src0_nb0
);
const
float
*
w_block
=
(
const
float
*
)
((
const
char
*
)
src1
+
bidy
*
split_d_inner
*
src1_nb1
);
float
*
y_block
=
(
float
*
)
((
char
*
)
dst
+
bidx
*
dst_nb2
+
bidz
*
split_n_t
*
dst_nb1
+
bidy
*
split_d_inner
*
dst_nb0
);
const
int
stride_x
=
src0_nb1
/
sizeof
(
float
);
const
int
stride_w
=
src1_nb1
/
sizeof
(
float
);
const
int
stride_y
=
dst_nb1
/
sizeof
(
float
);
float
x
[
d_conv
]
=
{
0.0
f
};
float
w
[
d_conv
]
=
{
0.0
f
};
#pragma unroll
for
(
size_t
j
=
0
;
j
<
d_conv
;
j
++
)
{
w
[
j
]
=
w_block
[
tid
*
stride_w
+
j
];
}
#pragma unroll
for
(
int64_t
i
=
0
;
i
<
split_n_t
;
i
++
)
{
if
(
bidz
*
split_n_t
+
i
<
n_t
)
{
float
sumf
=
0.0
f
;
if
(
i
==
0
)
{
for
(
size_t
j
=
0
;
j
<
d_conv
;
j
++
)
{
x
[
j
]
=
x_block
[
tid
*
stride_x
+
j
];
}
}
else
{
x
[(
i
-
1
)
%
d_conv
]
=
x_block
[
tid
*
stride_x
+
i
+
d_conv
-
1
];
}
#pragma unroll
for
(
size_t
j
=
0
;
j
<
d_conv
;
j
++
)
{
sumf
+=
x
[(
i
+
j
)
%
d_conv
]
*
w
[
j
];
}
y_block
[
i
*
stride_y
+
tid
]
=
sumf
;
}
}
}
static
void
ssm_conv_f32_cuda
(
const
float
*
src0
,
const
float
*
src1
,
const
int
src0_nb0
,
const
int
src0_nb1
,
const
int
src0_nb2
,
const
int
src1_nb1
,
float
*
dst
,
const
int
dst_nb0
,
const
int
dst_nb1
,
const
int
dst_nb2
,
const
int64_t
nc
,
const
int64_t
nr
,
const
int64_t
n_t
,
const
int64_t
n_s
,
cudaStream_t
stream
)
{
const
int
threads
=
128
;
GGML_ASSERT
(
nr
%
threads
==
0
);
if
(
n_t
<=
32
)
{
const
dim3
blocks
(
n_s
,
(
nr
+
threads
-
1
)
/
threads
,
1
);
if
(
nc
==
4
)
{
ssm_conv_f32
<
threads
,
4
><<<
blocks
,
threads
,
0
,
stream
>>>
(
src0
,
src1
,
src0_nb0
,
src0_nb1
,
src0_nb2
,
src1_nb1
,
dst
,
dst_nb0
,
dst_nb1
,
dst_nb2
,
n_t
);
}
else
{
GGML_ABORT
(
"Only support kernel size = 4 now."
);
}
}
else
{
if
(
nc
==
4
)
{
const
int64_t
split_n_t
=
32
;
dim3
blocks
(
n_s
,
(
nr
+
threads
-
1
)
/
threads
,
(
n_t
+
split_n_t
-
1
)
/
split_n_t
);
ssm_conv_long_token_f32
<
threads
,
4
,
split_n_t
><<<
blocks
,
threads
,
0
,
stream
>>>
(
src0
,
src1
,
src0_nb0
,
src0_nb1
,
src0_nb2
,
src1_nb1
,
dst
,
dst_nb0
,
dst_nb1
,
dst_nb2
,
n_t
);
}
else
{
GGML_ABORT
(
"Only support kernel size = 4 right now."
);
}
}
}
void
ggml_cuda_op_ssm_conv
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
struct
ggml_tensor
*
src0
=
dst
->
src
[
0
];
// conv_x
const
struct
ggml_tensor
*
src1
=
dst
->
src
[
1
];
// conv1d.weight
const
int64_t
nc
=
src1
->
ne
[
0
];
// d_conv
const
int64_t
nr
=
src0
->
ne
[
1
];
// d_inner
const
int64_t
n_t
=
dst
->
ne
[
1
];
// tokens per sequence
const
int64_t
n_s
=
dst
->
ne
[
2
];
// number of sequences in the batch
GGML_ASSERT
(
dst
->
ne
[
0
]
==
nr
);
GGML_ASSERT
(
src0
->
nb
[
0
]
==
sizeof
(
float
));
GGML_ASSERT
(
src1
->
nb
[
0
]
==
sizeof
(
float
));
GGML_ASSERT
(
src0
->
nb
[
1
]
==
src0
->
ne
[
0
]
*
sizeof
(
float
));
const
float
*
src0_d
=
(
const
float
*
)
src0
->
data
;
const
float
*
src1_d
=
(
const
float
*
)
src1
->
data
;
float
*
dst_d
=
(
float
*
)
dst
->
data
;
cudaStream_t
stream
=
ctx
.
stream
();
GGML_ASSERT
(
src0
->
type
==
GGML_TYPE_F32
);
GGML_ASSERT
(
dst
->
type
==
GGML_TYPE_F32
);
ssm_conv_f32_cuda
(
src0_d
,
src1_d
,
src0
->
nb
[
0
],
src0
->
nb
[
1
],
src0
->
nb
[
2
],
src1
->
nb
[
1
],
dst_d
,
dst
->
nb
[
0
],
dst
->
nb
[
1
],
dst
->
nb
[
2
],
nc
,
nr
,
n_t
,
n_s
,
stream
);
}
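To clarify what ssm_conv_f32 computes per channel, here is a CPU reference sketch (the helper name and the std::vector layout are our assumptions, not part of the commit): each channel applies a width-d_conv sliding dot product over its input, where the input carries d_conv - 1 columns of context in front of the n_t new tokens. The CUDA kernel reaches the same result with a per-thread rolling buffer, x[(i + j) % d_conv], so only one new input element is read per token.

#include <cstdint>
#include <vector>

// CPU reference for the per-channel sliding-window convolution (hypothetical helper).
// x[c] has n_t + d_conv - 1 samples (context followed by new tokens), w[c] has d_conv
// taps, and y[c] receives n_t outputs: y[c][t] = sum_j x[c][t + j] * w[c][j].
static void ssm_conv_ref(const std::vector<std::vector<float>> & x,
                         const std::vector<std::vector<float>> & w,
                         std::vector<std::vector<float>> &       y,
                         int64_t n_t, int d_conv) {
    for (size_t c = 0; c < x.size(); ++c) {
        for (int64_t t = 0; t < n_t; ++t) {
            float sum = 0.0f;
            for (int j = 0; j < d_conv; ++j) {
                sum += x[c][t + j] * w[c][j];
            }
            y[c][t] = sum;
        }
    }
}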
ml/backend/ggml/ggml/src/ggml-cuda/ssm-conv.cuh
0 → 100644
View file @
943464cc
#include "common.cuh"
void
ggml_cuda_op_ssm_conv
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
);
ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cu
0 → 100644
View file @
943464cc
#include "ssm-scan.cuh"
template
<
size_t
splitD
,
size_t
N
>
__global__
void
__launch_bounds__
(
splitD
,
2
)
ssm_scan_f32
(
const
float
*
__restrict__
src0
,
const
float
*
__restrict__
src1
,
const
float
*
__restrict__
src2
,
const
float
*
__restrict__
src3
,
const
float
*
__restrict__
src4
,
const
float
*
__restrict__
src5
,
const
int
src0_nb1
,
const
int
src0_nb2
,
const
int
src1_nb0
,
const
int
src1_nb1
,
const
int
src1_nb2
,
const
int
src1_nb3
,
const
int
src2_nb0
,
const
int
src2_nb1
,
const
int
src2_nb2
,
const
int
src3_nb1
,
const
int
src4_nb1
,
const
int
src4_nb2
,
const
int
src5_nb1
,
const
int
src5_nb2
,
float
*
__restrict__
dst
,
const
int64_t
L
)
{
GGML_UNUSED
(
src1_nb0
);
GGML_UNUSED
(
src2_nb0
);
const
int
bidx
=
blockIdx
.
x
;
// split along B
const
int
bidy
=
blockIdx
.
y
;
// split along D
const
int
tid
=
threadIdx
.
x
;
const
int
wid
=
tid
/
32
;
const
int
wtid
=
tid
%
32
;
extern
__shared__
float
smem
[];
const
int
stride_sA
=
N
+
1
;
const
int
stride_ss0
=
N
+
1
;
float
*
smem_A
=
smem
;
float
*
smem_s0
=
smem_A
+
splitD
*
stride_sA
;
const
float
*
s0_block
=
(
const
float
*
)
((
const
char
*
)
src0
+
bidx
*
src0_nb2
+
bidy
*
splitD
*
src0_nb1
);
const
float
*
x_block
=
(
const
float
*
)
((
const
char
*
)
src1
+
(
bidx
*
src1_nb2
)
+
bidy
*
splitD
*
sizeof
(
float
));
const
float
*
dt_block
=
(
const
float
*
)
((
const
char
*
)
src2
+
(
bidx
*
src2_nb2
)
+
bidy
*
splitD
*
sizeof
(
float
));
const
float
*
A_block
=
(
const
float
*
)
((
const
char
*
)
src3
+
bidy
*
splitD
*
src3_nb1
);
const
float
*
B_block
=
(
const
float
*
)
((
const
char
*
)
src4
+
(
bidx
*
src4_nb2
));
const
float
*
C_block
=
(
const
float
*
)
((
const
char
*
)
src5
+
(
bidx
*
src5_nb2
));
float
*
y_block
=
(
float
*
)
((
char
*
)
dst
+
(
bidx
*
src1_nb2
)
+
bidy
*
splitD
*
sizeof
(
float
));
float
*
s_block
=
(
float
*
)
((
char
*
)
dst
+
src1_nb3
+
bidx
*
src0_nb2
+
bidy
*
splitD
*
src0_nb1
);
const
int
stride_s0
=
src0_nb1
/
sizeof
(
float
);
const
int
stride_x
=
src1_nb1
/
sizeof
(
float
);
const
int
stride_dt
=
src2_nb1
/
sizeof
(
float
);
const
int
stride_A
=
src3_nb1
/
sizeof
(
float
);
const
int
stride_B
=
src4_nb1
/
sizeof
(
float
);
const
int
stride_C
=
src5_nb1
/
sizeof
(
float
);
const
int
stride_s
=
stride_s0
;
const
int
stride_y
=
stride_x
;
// can N not be 16? for example 32?
if
(
N
==
16
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
splitD
/
4
;
i
+=
2
)
{
float
value
=
A_block
[(
wid
*
warpSize
+
i
)
*
stride_A
+
wtid
];
// todo: bank conflict
// I am always confused with how to use the swizzling method to solve
// bank conflit. Hoping somebody can tell me.
smem_A
[(
wid
*
warpSize
+
i
)
*
stride_sA
+
wtid
+
((
wtid
/
16
)
>
0
?
1
:
0
)]
=
value
;
}
#pragma unroll
for
(
size_t
i
=
0
;
i
<
splitD
/
4
;
i
+=
2
)
{
float
value
=
s0_block
[(
wid
*
warpSize
+
i
)
*
stride_s0
+
wtid
];
smem_s0
[(
wid
*
warpSize
+
i
)
*
stride_ss0
+
wtid
+
((
wtid
/
16
)
>
0
?
1
:
0
)]
=
value
;
}
}
__syncthreads
();
for
(
int64_t
i
=
0
;
i
<
L
;
i
++
)
{
float
dt_soft_plus
=
dt_block
[
i
*
stride_dt
+
tid
];
if
(
dt_soft_plus
<=
20.0
f
)
{
dt_soft_plus
=
log1pf
(
exp
(
dt_soft_plus
));
}
float
x_dt
=
x_block
[
i
*
stride_x
+
tid
]
*
dt_soft_plus
;
float
sumf
=
0.0
f
;
#pragma unroll
for
(
size_t
j
=
0
;
j
<
N
;
j
++
)
{
float
state
=
(
smem_s0
[
tid
*
stride_ss0
+
j
]
*
expf
(
dt_soft_plus
*
smem_A
[
tid
*
stride_sA
+
j
]))
+
(
B_block
[
i
*
stride_B
+
j
]
*
x_dt
);
sumf
+=
state
*
C_block
[
i
*
stride_C
+
j
];
if
(
i
==
L
-
1
)
{
s_block
[
tid
*
stride_s
+
j
]
=
state
;
}
else
{
smem_s0
[
tid
*
stride_ss0
+
j
]
=
state
;
}
}
__syncthreads
();
y_block
[
i
*
stride_y
+
tid
]
=
sumf
;
}
}
static
void
ssm_scan_f32_cuda
(
const
float
*
src0
,
const
float
*
src1
,
const
float
*
src2
,
const
float
*
src3
,
const
float
*
src4
,
const
float
*
src5
,
const
int
src0_nb1
,
const
int
src0_nb2
,
const
int
src1_nb0
,
const
int
src1_nb1
,
const
int
src1_nb2
,
const
int
src1_nb3
,
const
int
src2_nb0
,
const
int
src2_nb1
,
const
int
src2_nb2
,
const
int
src3_nb1
,
const
int
src4_nb1
,
const
int
src4_nb2
,
const
int
src5_nb1
,
const
int
src5_nb2
,
float
*
dst
,
const
int64_t
N
,
const
int64_t
D
,
const
int64_t
L
,
const
int64_t
B
,
cudaStream_t
stream
)
{
const
int
threads
=
128
;
// todo: consider D cannot be divided,does this situation exist?
GGML_ASSERT
(
D
%
threads
==
0
);
const
dim3
blocks
(
B
,
(
D
+
threads
-
1
)
/
threads
,
1
);
const
int
smem_size
=
(
threads
*
(
N
+
1
)
*
2
)
*
sizeof
(
float
);
if
(
N
==
16
)
{
ssm_scan_f32
<
128
,
16
><<<
blocks
,
threads
,
smem_size
,
stream
>>>
(
src0
,
src1
,
src2
,
src3
,
src4
,
src5
,
src0_nb1
,
src0_nb2
,
src1_nb0
,
src1_nb1
,
src1_nb2
,
src1_nb3
,
src2_nb0
,
src2_nb1
,
src2_nb2
,
src3_nb1
,
src4_nb1
,
src4_nb2
,
src5_nb1
,
src5_nb2
,
dst
,
L
);
}
else
{
GGML_ABORT
(
"doesn't support N!=16."
);
}
}
void
ggml_cuda_op_ssm_scan
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
struct
ggml_tensor
*
src0
=
dst
->
src
[
0
];
// s
const
struct
ggml_tensor
*
src1
=
dst
->
src
[
1
];
// x
const
struct
ggml_tensor
*
src2
=
dst
->
src
[
2
];
// dt
const
struct
ggml_tensor
*
src3
=
dst
->
src
[
3
];
// A
const
struct
ggml_tensor
*
src4
=
dst
->
src
[
4
];
// B
const
struct
ggml_tensor
*
src5
=
dst
->
src
[
5
];
// C
// const int64_t d_state = src0->ne[0];
// const int64_t d_inner = src0->ne[1];
// const int64_t l = src1->ne[1];
// const int64_t b = src0->ne[2];
const
int64_t
nc
=
src0
->
ne
[
0
];
// d_state
const
int64_t
nr
=
src0
->
ne
[
1
];
// d_inner
const
int64_t
n_t
=
src1
->
ne
[
1
];
// number of tokens per sequence
const
int64_t
n_s
=
src0
->
ne
[
2
];
// number of sequences in the batch
GGML_ASSERT
(
ggml_nelements
(
src1
)
+
ggml_nelements
(
src0
)
==
ggml_nelements
(
dst
));
GGML_ASSERT
(
src0
->
nb
[
0
]
==
sizeof
(
float
));
GGML_ASSERT
(
src1
->
nb
[
0
]
==
sizeof
(
float
));
GGML_ASSERT
(
src2
->
nb
[
0
]
==
sizeof
(
float
));
GGML_ASSERT
(
src3
->
nb
[
0
]
==
sizeof
(
float
));
GGML_ASSERT
(
src4
->
nb
[
0
]
==
sizeof
(
float
));
GGML_ASSERT
(
src5
->
nb
[
0
]
==
sizeof
(
float
));
// required for the dot product between s and C
GGML_ASSERT
(
src0
->
nb
[
1
]
==
src0
->
ne
[
0
]
*
sizeof
(
float
));
// required for per-sequence offsets for states
GGML_ASSERT
(
src0
->
nb
[
2
]
==
src0
->
ne
[
0
]
*
src0
->
ne
[
1
]
*
sizeof
(
float
));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT
(
src1
->
nb
[
3
]
==
src1
->
ne
[
0
]
*
src1
->
ne
[
1
]
*
src1
->
ne
[
2
]
*
sizeof
(
float
));
const
float
*
src0_d
=
(
const
float
*
)
src0
->
data
;
const
float
*
src1_d
=
(
const
float
*
)
src1
->
data
;
const
float
*
src2_d
=
(
const
float
*
)
src2
->
data
;
const
float
*
src3_d
=
(
const
float
*
)
src3
->
data
;
const
float
*
src4_d
=
(
const
float
*
)
src4
->
data
;
const
float
*
src5_d
=
(
const
float
*
)
src5
->
data
;
float
*
dst_d
=
(
float
*
)
dst
->
data
;
cudaStream_t
stream
=
ctx
.
stream
();
GGML_ASSERT
(
src0
->
type
==
GGML_TYPE_F32
);
GGML_ASSERT
(
dst
->
type
==
GGML_TYPE_F32
);
ssm_scan_f32_cuda
(
src0_d
,
src1_d
,
src2_d
,
src3_d
,
src4_d
,
src5_d
,
src0
->
nb
[
1
],
src0
->
nb
[
2
],
src1
->
nb
[
0
],
src1
->
nb
[
1
],
src1
->
nb
[
2
],
src1
->
nb
[
3
],
src2
->
nb
[
0
],
src2
->
nb
[
1
],
src2
->
nb
[
2
],
src3
->
nb
[
1
],
src4
->
nb
[
1
],
src4
->
nb
[
2
],
src5
->
nb
[
1
],
src5
->
nb
[
2
],
dst_d
,
nc
,
nr
,
n_t
,
n_s
,
stream
);
}
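For reference, the recurrence each thread of ssm_scan_f32 evaluates for its channel can be sketched on the CPU as below (the helper name and the std::vector layout are our assumptions, not part of the commit): dt is passed through softplus (skipped above ~20, where softplus(v) is effectively v), the d_state entries of the running state are decayed by exp(dt * A) and driven by B * x * dt, and the token's output is the dot product of the updated state with C.

#include <cmath>
#include <cstddef>
#include <vector>

// CPU reference for the per-channel selective-scan recurrence (hypothetical helper):
//   s[j] <- s[j] * exp(dt' * A[j]) + B[t][j] * x[t] * dt'
//   y[t]  = sum_j s[j] * C[t][j]
// with dt' = softplus(dt[t]) when dt[t] <= 20, else dt[t].
static void ssm_scan_channel_ref(std::vector<float> &                    s,   // d_state, updated in place
                                 const std::vector<float> &              A,   // d_state
                                 const std::vector<float> &              x,   // L tokens
                                 const std::vector<float> &              dt,  // L tokens
                                 const std::vector<std::vector<float>> & B,   // L x d_state
                                 const std::vector<std::vector<float>> & C,   // L x d_state
                                 std::vector<float> &                    y) { // L outputs
    const size_t d_state = s.size();
    for (size_t t = 0; t < x.size(); ++t) {
        const float dtp  = dt[t] <= 20.0f ? std::log1p(std::exp(dt[t])) : dt[t];
        const float x_dt = x[t] * dtp;
        float       sum  = 0.0f;
        for (size_t j = 0; j < d_state; ++j) {
            s[j] = s[j] * std::exp(dtp * A[j]) + B[t][j] * x_dt;
            sum += s[j] * C[t][j];
        }
        y[t] = sum;
    }
}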
ml/backend/ggml/ggml/src/ggml-cuda/ssm-scan.cuh
0 → 100644
View file @
943464cc
#include "common.cuh"
void
ggml_cuda_op_ssm_scan
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
);