Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ollama
Commits
d3ad6274
Commit
d3ad6274
authored
Nov 12, 2024
by
xuxzh1
🎱
Browse files
init
parent
97b02a89
Changes
193
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
3561 deletions
+0
-3561
llm/llama.cpp/ggml-cuda/cpy.cu
llm/llama.cpp/ggml-cuda/cpy.cu
+0
-490
llm/llama.cpp/ggml-cuda/cpy.cuh
llm/llama.cpp/ggml-cuda/cpy.cuh
+0
-9
llm/llama.cpp/ggml-cuda/dequantize.cuh
llm/llama.cpp/ggml-cuda/dequantize.cuh
+0
-103
llm/llama.cpp/ggml-cuda/diagmask.cu
llm/llama.cpp/ggml-cuda/diagmask.cu
+0
-40
llm/llama.cpp/ggml-cuda/diagmask.cuh
llm/llama.cpp/ggml-cuda/diagmask.cuh
+0
-5
llm/llama.cpp/ggml-cuda/dmmv.cu
llm/llama.cpp/ggml-cuda/dmmv.cu
+0
-662
llm/llama.cpp/ggml-cuda/dmmv.cuh
llm/llama.cpp/ggml-cuda/dmmv.cuh
+0
-18
llm/llama.cpp/ggml-cuda/fattn-common.cuh
llm/llama.cpp/ggml-cuda/fattn-common.cuh
+0
-162
llm/llama.cpp/ggml-cuda/fattn-tile-f16.cu
llm/llama.cpp/ggml-cuda/fattn-tile-f16.cu
+0
-316
llm/llama.cpp/ggml-cuda/fattn-tile-f16.cuh
llm/llama.cpp/ggml-cuda/fattn-tile-f16.cuh
+0
-3
llm/llama.cpp/ggml-cuda/fattn-tile-f32.cu
llm/llama.cpp/ggml-cuda/fattn-tile-f32.cu
+0
-309
llm/llama.cpp/ggml-cuda/fattn-tile-f32.cuh
llm/llama.cpp/ggml-cuda/fattn-tile-f32.cuh
+0
-3
llm/llama.cpp/ggml-cuda/fattn-vec-f16.cu
llm/llama.cpp/ggml-cuda/fattn-vec-f16.cu
+0
-330
llm/llama.cpp/ggml-cuda/fattn-vec-f16.cuh
llm/llama.cpp/ggml-cuda/fattn-vec-f16.cuh
+0
-5
llm/llama.cpp/ggml-cuda/fattn-vec-f32.cu
llm/llama.cpp/ggml-cuda/fattn-vec-f32.cu
+0
-279
llm/llama.cpp/ggml-cuda/fattn-vec-f32.cuh
llm/llama.cpp/ggml-cuda/fattn-vec-f32.cuh
+0
-3
llm/llama.cpp/ggml-cuda/fattn.cu
llm/llama.cpp/ggml-cuda/fattn.cu
+0
-638
llm/llama.cpp/ggml-cuda/fattn.cuh
llm/llama.cpp/ggml-cuda/fattn.cuh
+0
-3
llm/llama.cpp/ggml-cuda/getrows.cu
llm/llama.cpp/ggml-cuda/getrows.cu
+0
-178
llm/llama.cpp/ggml-cuda/getrows.cuh
llm/llama.cpp/ggml-cuda/getrows.cuh
+0
-5
No files found.
Too many changes to show.
To preserve performance only
193 of 193+
files are displayed.
Plain diff
Email patch
llm/llama.cpp/ggml-cuda/cpy.cu
deleted
100644 → 0
View file @
97b02a89
#include "cpy.cuh"
typedef
void
(
*
cpy_kernel_t
)(
const
char
*
cx
,
char
*
cdst
);
static
__device__
void
cpy_1_f32_f32
(
const
char
*
cxi
,
char
*
cdsti
)
{
const
float
*
xi
=
(
const
float
*
)
cxi
;
float
*
dsti
=
(
float
*
)
cdsti
;
*
dsti
=
*
xi
;
}
static
__device__
void
cpy_1_f32_f16
(
const
char
*
cxi
,
char
*
cdsti
)
{
const
float
*
xi
=
(
const
float
*
)
cxi
;
half
*
dsti
=
(
half
*
)
cdsti
;
*
dsti
=
__float2half
(
*
xi
);
}
static
__device__
void
cpy_1_f16_f16
(
const
char
*
cxi
,
char
*
cdsti
)
{
const
half
*
xi
=
(
const
half
*
)
cxi
;
half
*
dsti
=
(
half
*
)
cdsti
;
*
dsti
=
*
xi
;
}
static
__device__
void
cpy_1_f16_f32
(
const
char
*
cxi
,
char
*
cdsti
)
{
const
half
*
xi
=
(
const
half
*
)
cxi
;
float
*
dsti
=
(
float
*
)
cdsti
;
*
dsti
=
*
xi
;
}
template
<
cpy_kernel_t
cpy_1
>
static
__global__
__launch_bounds__
(
1024
)
void
cpy_f32_f16
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
)
{
const
int64_t
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
if
(
i
>=
ne
)
{
return
;
}
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
// then combine those indices with the corresponding byte offsets to get the total offsets
const
int64_t
i03
=
i
/
(
ne00
*
ne01
*
ne02
);
const
int64_t
i02
=
(
i
-
i03
*
ne00
*
ne01
*
ne02
)
/
(
ne00
*
ne01
);
const
int64_t
i01
=
(
i
-
i03
*
ne00
*
ne01
*
ne02
-
i02
*
ne01
*
ne00
)
/
ne00
;
const
int64_t
i00
=
i
-
i03
*
ne00
*
ne01
*
ne02
-
i02
*
ne01
*
ne00
-
i01
*
ne00
;
const
int64_t
x_offset
=
i00
*
nb00
+
i01
*
nb01
+
i02
*
nb02
+
i03
*
nb03
;
const
int64_t
i13
=
i
/
(
ne10
*
ne11
*
ne12
);
const
int64_t
i12
=
(
i
-
i13
*
ne10
*
ne11
*
ne12
)
/
(
ne10
*
ne11
);
const
int64_t
i11
=
(
i
-
i13
*
ne10
*
ne11
*
ne12
-
i12
*
ne10
*
ne11
)
/
ne10
;
const
int64_t
i10
=
i
-
i13
*
ne10
*
ne11
*
ne12
-
i12
*
ne10
*
ne11
-
i11
*
ne10
;
const
int64_t
dst_offset
=
i10
*
nb10
+
i11
*
nb11
+
i12
*
nb12
+
i13
*
nb13
;
cpy_1
(
cx
+
x_offset
,
cdst
+
dst_offset
);
}
static
__device__
void
cpy_blck_f32_q8_0
(
const
char
*
cxi
,
char
*
cdsti
)
{
const
float
*
xi
=
(
const
float
*
)
cxi
;
block_q8_0
*
dsti
=
(
block_q8_0
*
)
cdsti
;
float
amax
=
0.0
f
;
// absolute max
for
(
int
j
=
0
;
j
<
QK8_0
;
j
++
)
{
const
float
v
=
xi
[
j
];
amax
=
fmaxf
(
amax
,
fabsf
(
v
));
}
const
float
d
=
amax
/
((
1
<<
7
)
-
1
);
const
float
id
=
d
?
1.0
f
/
d
:
0.0
f
;
dsti
->
d
=
d
;
for
(
int
j
=
0
;
j
<
QK8_0
;
++
j
)
{
const
float
x0
=
xi
[
j
]
*
id
;
dsti
->
qs
[
j
]
=
roundf
(
x0
);
}
}
static
__device__
void
cpy_blck_f32_q4_0
(
const
char
*
cxi
,
char
*
cdsti
)
{
const
float
*
xi
=
(
const
float
*
)
cxi
;
block_q4_0
*
dsti
=
(
block_q4_0
*
)
cdsti
;
float
amax
=
0.0
f
;
float
vmax
=
0.0
f
;
for
(
int
j
=
0
;
j
<
QK4_0
;
++
j
)
{
const
float
v
=
xi
[
j
];
if
(
amax
<
fabsf
(
v
))
{
amax
=
fabsf
(
v
);
vmax
=
v
;
}
}
const
float
d
=
vmax
/
-
8
;
const
float
id
=
d
?
1.0
f
/
d
:
0.0
f
;
dsti
->
d
=
d
;
for
(
int
j
=
0
;
j
<
QK4_0
/
2
;
++
j
)
{
const
float
x0
=
xi
[
0
+
j
]
*
id
;
const
float
x1
=
xi
[
QK4_0
/
2
+
j
]
*
id
;
const
uint8_t
xi0
=
min
(
15
,
(
int8_t
)(
x0
+
8.5
f
));
const
uint8_t
xi1
=
min
(
15
,
(
int8_t
)(
x1
+
8.5
f
));
dsti
->
qs
[
j
]
=
xi0
;
dsti
->
qs
[
j
]
|=
xi1
<<
4
;
}
}
static
__device__
void
cpy_blck_f32_q4_1
(
const
char
*
cxi
,
char
*
cdsti
)
{
const
float
*
xi
=
(
const
float
*
)
cxi
;
block_q4_1
*
dsti
=
(
block_q4_1
*
)
cdsti
;
float
vmin
=
FLT_MAX
;
float
vmax
=
-
FLT_MAX
;
for
(
int
j
=
0
;
j
<
QK4_1
;
++
j
)
{
const
float
v
=
xi
[
j
];
if
(
v
<
vmin
)
vmin
=
v
;
if
(
v
>
vmax
)
vmax
=
v
;
}
const
float
d
=
(
vmax
-
vmin
)
/
((
1
<<
4
)
-
1
);
const
float
id
=
d
?
1.0
f
/
d
:
0.0
f
;
dsti
->
dm
.
x
=
d
;
dsti
->
dm
.
y
=
vmin
;
for
(
int
j
=
0
;
j
<
QK4_1
/
2
;
++
j
)
{
const
float
x0
=
(
xi
[
0
+
j
]
-
vmin
)
*
id
;
const
float
x1
=
(
xi
[
QK4_1
/
2
+
j
]
-
vmin
)
*
id
;
const
uint8_t
xi0
=
min
(
15
,
(
int8_t
)(
x0
+
0.5
f
));
const
uint8_t
xi1
=
min
(
15
,
(
int8_t
)(
x1
+
0.5
f
));
dsti
->
qs
[
j
]
=
xi0
;
dsti
->
qs
[
j
]
|=
xi1
<<
4
;
}
}
static
__device__
void
cpy_blck_f32_q5_0
(
const
char
*
cxi
,
char
*
cdsti
)
{
const
float
*
xi
=
(
const
float
*
)
cxi
;
block_q5_0
*
dsti
=
(
block_q5_0
*
)
cdsti
;
float
amax
=
0.0
f
;
float
vmax
=
0.0
f
;
for
(
int
j
=
0
;
j
<
QK5_0
;
++
j
)
{
const
float
v
=
xi
[
j
];
if
(
amax
<
fabsf
(
v
))
{
amax
=
fabsf
(
v
);
vmax
=
v
;
}
}
const
float
d
=
vmax
/
-
16
;
const
float
id
=
d
?
1.0
f
/
d
:
0.0
f
;
dsti
->
d
=
d
;
uint32_t
qh
=
0
;
for
(
int
j
=
0
;
j
<
QK5_0
/
2
;
++
j
)
{
const
float
x0
=
xi
[
0
+
j
]
*
id
;
const
float
x1
=
xi
[
QK5_0
/
2
+
j
]
*
id
;
const
uint8_t
xi0
=
min
(
31
,
(
int8_t
)(
x0
+
16.5
f
));
const
uint8_t
xi1
=
min
(
31
,
(
int8_t
)(
x1
+
16.5
f
));
dsti
->
qs
[
j
]
=
(
xi0
&
0xf
)
|
((
xi1
&
0xf
)
<<
4
);
qh
|=
((
xi0
&
0x10u
)
>>
4
)
<<
(
j
+
0
);
qh
|=
((
xi1
&
0x10u
)
>>
4
)
<<
(
j
+
QK5_0
/
2
);
}
memcpy
(
dsti
->
qh
,
&
qh
,
sizeof
(
qh
));
}
static
__device__
void
cpy_blck_f32_q5_1
(
const
char
*
cxi
,
char
*
cdsti
)
{
const
float
*
xi
=
(
const
float
*
)
cxi
;
block_q5_1
*
dsti
=
(
block_q5_1
*
)
cdsti
;
float
min
=
xi
[
0
];
float
max
=
xi
[
0
];
for
(
int
j
=
1
;
j
<
QK5_1
;
++
j
)
{
const
float
v
=
xi
[
j
];
min
=
v
<
min
?
v
:
min
;
max
=
v
>
max
?
v
:
max
;
}
const
float
d
=
(
max
-
min
)
/
31
;
const
float
id
=
d
?
1.0
f
/
d
:
0.0
f
;
dsti
->
dm
.
x
=
d
;
dsti
->
dm
.
y
=
min
;
uint32_t
qh
=
0
;
for
(
int
j
=
0
;
j
<
QK5_1
/
2
;
++
j
)
{
const
float
x0
=
(
xi
[
0
+
j
]
-
min
)
*
id
;
const
float
x1
=
(
xi
[
QK5_1
/
2
+
j
]
-
min
)
*
id
;
const
uint8_t
xi0
=
(
uint8_t
)(
x0
+
0.5
f
);
const
uint8_t
xi1
=
(
uint8_t
)(
x1
+
0.5
f
);
dsti
->
qs
[
j
]
=
(
xi0
&
0xf
)
|
((
xi1
&
0xf
)
<<
4
);
qh
|=
((
xi0
&
0x10u
)
>>
4
)
<<
(
j
+
0
);
qh
|=
((
xi1
&
0x10u
)
>>
4
)
<<
(
j
+
QK5_1
/
2
);
}
memcpy
(
dsti
->
qh
,
&
qh
,
sizeof
(
qh
));
}
static
__device__
__forceinline__
int
best_index_int8
(
int
n
,
const
int8_t
*
val
,
float
x
)
{
if
(
x
<=
val
[
0
])
return
0
;
if
(
x
>=
val
[
n
-
1
])
return
n
-
1
;
int
ml
=
0
,
mu
=
n
-
1
;
while
(
mu
-
ml
>
1
)
{
int
mav
=
(
ml
+
mu
)
/
2
;
if
(
x
<
val
[
mav
])
mu
=
mav
;
else
ml
=
mav
;
}
return
x
-
val
[
mu
-
1
]
<
val
[
mu
]
-
x
?
mu
-
1
:
mu
;
}
static
__device__
void
cpy_blck_f32_iq4_nl
(
const
char
*
cxi
,
char
*
cdsti
)
{
const
float
*
xi
=
(
const
float
*
)
cxi
;
block_iq4_nl
*
dsti
=
(
block_iq4_nl
*
)
cdsti
;
float
amax
=
0.0
f
;
float
vmax
=
0.0
f
;
for
(
int
j
=
0
;
j
<
QK4_NL
;
++
j
)
{
const
float
v
=
xi
[
j
];
if
(
amax
<
fabsf
(
v
))
{
amax
=
fabsf
(
v
);
vmax
=
v
;
}
}
float
d
=
vmax
/
kvalues_iq4nl
[
0
];
const
float
id
=
d
?
1.0
f
/
d
:
0.0
f
;
float
sumqx
=
0
,
sumq2
=
0
;
for
(
int
j
=
0
;
j
<
QK4_NL
/
2
;
++
j
)
{
const
float
x0
=
xi
[
0
+
j
]
*
id
;
const
float
x1
=
xi
[
QK4_NL
/
2
+
j
]
*
id
;
const
uint8_t
xi0
=
best_index_int8
(
16
,
kvalues_iq4nl
,
x0
);
const
uint8_t
xi1
=
best_index_int8
(
16
,
kvalues_iq4nl
,
x1
);
dsti
->
qs
[
j
]
=
xi0
|
(
xi1
<<
4
);
const
float
v0
=
kvalues_iq4nl
[
xi0
];
const
float
v1
=
kvalues_iq4nl
[
xi1
];
const
float
w0
=
xi
[
0
+
j
]
*
xi
[
0
+
j
];
const
float
w1
=
xi
[
QK4_NL
/
2
+
j
]
*
xi
[
QK4_NL
/
2
+
j
];
sumqx
+=
w0
*
v0
*
xi
[
j
]
+
w1
*
v1
*
xi
[
QK4_NL
/
2
+
j
];
sumq2
+=
w0
*
v0
*
v0
+
w1
*
v1
*
v1
;
}
dsti
->
d
=
sumq2
>
0
?
sumqx
/
sumq2
:
d
;
}
template
<
cpy_kernel_t
cpy_blck
,
int
qk
>
static
__global__
__launch_bounds__
(
1024
)
void
cpy_f32_q
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
)
{
const
int
i
=
(
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
)
*
qk
;
if
(
i
>=
ne
)
{
return
;
}
const
int
i03
=
i
/
(
ne00
*
ne01
*
ne02
);
const
int
i02
=
(
i
-
i03
*
ne00
*
ne01
*
ne02
)
/
(
ne00
*
ne01
);
const
int
i01
=
(
i
-
i03
*
ne00
*
ne01
*
ne02
-
i02
*
ne01
*
ne00
)
/
ne00
;
const
int
i00
=
i
-
i03
*
ne00
*
ne01
*
ne02
-
i02
*
ne01
*
ne00
-
i01
*
ne00
;
const
int
x_offset
=
i00
*
nb00
+
i01
*
nb01
+
i02
*
nb02
+
i03
*
nb03
;
const
int
i13
=
i
/
(
ne10
*
ne11
*
ne12
);
const
int
i12
=
(
i
-
i13
*
ne10
*
ne11
*
ne12
)
/
(
ne10
*
ne11
);
const
int
i11
=
(
i
-
i13
*
ne10
*
ne11
*
ne12
-
i12
*
ne10
*
ne11
)
/
ne10
;
const
int
i10
=
i
-
i13
*
ne10
*
ne11
*
ne12
-
i12
*
ne10
*
ne11
-
i11
*
ne10
;
const
int
dst_offset
=
(
i10
/
qk
)
*
nb10
+
i11
*
nb11
+
i12
*
nb12
+
i13
*
nb13
;
cpy_blck
(
cx
+
x_offset
,
cdst
+
dst_offset
);
}
static
void
ggml_cpy_f16_f32_cuda
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
cudaStream_t
stream
)
{
const
int
num_blocks
=
(
ne
+
CUDA_CPY_BLOCK_SIZE
-
1
)
/
CUDA_CPY_BLOCK_SIZE
;
cpy_f32_f16
<
cpy_1_f16_f32
><<<
num_blocks
,
CUDA_CPY_BLOCK_SIZE
,
0
,
stream
>>>
(
cx
,
cdst
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
);
}
static
void
ggml_cpy_f32_f32_cuda
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
cudaStream_t
stream
)
{
const
int
num_blocks
=
(
ne
+
CUDA_CPY_BLOCK_SIZE
-
1
)
/
CUDA_CPY_BLOCK_SIZE
;
cpy_f32_f16
<
cpy_1_f32_f32
><<<
num_blocks
,
CUDA_CPY_BLOCK_SIZE
,
0
,
stream
>>>
(
cx
,
cdst
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
);
}
static
void
ggml_cpy_f32_f16_cuda
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
cudaStream_t
stream
)
{
const
int
num_blocks
=
(
ne
+
CUDA_CPY_BLOCK_SIZE
-
1
)
/
CUDA_CPY_BLOCK_SIZE
;
cpy_f32_f16
<
cpy_1_f32_f16
><<<
num_blocks
,
CUDA_CPY_BLOCK_SIZE
,
0
,
stream
>>>
(
cx
,
cdst
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
);
}
static
void
ggml_cpy_f32_q8_0_cuda
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ne
%
QK8_0
==
0
);
const
int
num_blocks
=
ne
/
QK8_0
;
cpy_f32_q
<
cpy_blck_f32_q8_0
,
QK8_0
><<<
num_blocks
,
1
,
0
,
stream
>>>
(
cx
,
cdst
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
);
}
static
void
ggml_cpy_f32_q4_0_cuda
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ne
%
QK4_0
==
0
);
const
int
num_blocks
=
ne
/
QK4_0
;
cpy_f32_q
<
cpy_blck_f32_q4_0
,
QK4_0
><<<
num_blocks
,
1
,
0
,
stream
>>>
(
cx
,
cdst
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
);
}
static
void
ggml_cpy_f32_q4_1_cuda
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ne
%
QK4_1
==
0
);
const
int
num_blocks
=
ne
/
QK4_1
;
cpy_f32_q
<
cpy_blck_f32_q4_1
,
QK4_1
><<<
num_blocks
,
1
,
0
,
stream
>>>
(
cx
,
cdst
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
);
}
static
void
ggml_cpy_f32_q5_0_cuda
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ne
%
QK5_0
==
0
);
const
int
num_blocks
=
ne
/
QK5_0
;
cpy_f32_q
<
cpy_blck_f32_q5_0
,
QK5_0
><<<
num_blocks
,
1
,
0
,
stream
>>>
(
cx
,
cdst
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
);
}
static
void
ggml_cpy_f32_q5_1_cuda
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ne
%
QK5_1
==
0
);
const
int
num_blocks
=
ne
/
QK5_1
;
cpy_f32_q
<
cpy_blck_f32_q5_1
,
QK5_1
><<<
num_blocks
,
1
,
0
,
stream
>>>
(
cx
,
cdst
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
);
}
static
void
ggml_cpy_f32_iq4_nl_cuda
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ne
%
QK4_NL
==
0
);
const
int
num_blocks
=
ne
/
QK4_NL
;
cpy_f32_q
<
cpy_blck_f32_iq4_nl
,
QK4_NL
><<<
num_blocks
,
1
,
0
,
stream
>>>
(
cx
,
cdst
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
);
}
static
void
ggml_cpy_f16_f16_cuda
(
const
char
*
cx
,
char
*
cdst
,
const
int
ne
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
nb00
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
nb10
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
cudaStream_t
stream
)
{
const
int
num_blocks
=
(
ne
+
CUDA_CPY_BLOCK_SIZE
-
1
)
/
CUDA_CPY_BLOCK_SIZE
;
cpy_f32_f16
<
cpy_1_f16_f16
><<<
num_blocks
,
CUDA_CPY_BLOCK_SIZE
,
0
,
stream
>>>
(
cx
,
cdst
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
);
}
void
ggml_cuda_cpy
(
ggml_backend_cuda_context
&
ctx
,
const
ggml_tensor
*
src0
,
ggml_tensor
*
src1
)
{
const
int64_t
ne
=
ggml_nelements
(
src0
);
GGML_ASSERT
(
ne
==
ggml_nelements
(
src1
));
GGML_ASSERT
(
ggml_nbytes
(
src0
)
<=
INT_MAX
);
GGML_ASSERT
(
ggml_nbytes
(
src1
)
<=
INT_MAX
);
const
int64_t
ne00
=
src0
->
ne
[
0
];
const
int64_t
ne01
=
src0
->
ne
[
1
];
const
int64_t
ne02
=
src0
->
ne
[
2
];
//GGML_ASSERT(src0->ne[3] == 1);
const
int64_t
nb00
=
src0
->
nb
[
0
];
const
int64_t
nb01
=
src0
->
nb
[
1
];
const
int64_t
nb02
=
src0
->
nb
[
2
];
const
int64_t
nb03
=
src0
->
nb
[
3
];
const
int64_t
ne10
=
src1
->
ne
[
0
];
const
int64_t
ne11
=
src1
->
ne
[
1
];
const
int64_t
ne12
=
src1
->
ne
[
2
];
//GGML_ASSERT(src1->ne[3] == 1);
const
int64_t
nb10
=
src1
->
nb
[
0
];
const
int64_t
nb11
=
src1
->
nb
[
1
];
const
int64_t
nb12
=
src1
->
nb
[
2
];
const
int64_t
nb13
=
src1
->
nb
[
3
];
cudaStream_t
main_stream
=
ctx
.
stream
();
char
*
src0_ddc
=
(
char
*
)
src0
->
data
;
char
*
src1_ddc
=
(
char
*
)
src1
->
data
;
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_F32
)
{
ggml_cpy_f32_f32_cuda
(
src0_ddc
,
src1_ddc
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
,
main_stream
);
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_F16
)
{
ggml_cpy_f32_f16_cuda
(
src0_ddc
,
src1_ddc
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
,
main_stream
);
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_Q8_0
)
{
ggml_cpy_f32_q8_0_cuda
(
src0_ddc
,
src1_ddc
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
,
main_stream
);
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_Q4_0
)
{
ggml_cpy_f32_q4_0_cuda
(
src0_ddc
,
src1_ddc
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
,
main_stream
);
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_Q4_1
)
{
ggml_cpy_f32_q4_1_cuda
(
src0_ddc
,
src1_ddc
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
,
main_stream
);
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_Q5_0
)
{
ggml_cpy_f32_q5_0_cuda
(
src0_ddc
,
src1_ddc
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
,
main_stream
);
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_IQ4_NL
)
{
ggml_cpy_f32_iq4_nl_cuda
(
src0_ddc
,
src1_ddc
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
,
main_stream
);
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_Q5_1
)
{
ggml_cpy_f32_q5_1_cuda
(
src0_ddc
,
src1_ddc
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
,
main_stream
);
}
else
if
(
src0
->
type
==
GGML_TYPE_F16
&&
src1
->
type
==
GGML_TYPE_F16
)
{
ggml_cpy_f16_f16_cuda
(
src0_ddc
,
src1_ddc
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
,
main_stream
);
}
else
if
(
src0
->
type
==
GGML_TYPE_F16
&&
src1
->
type
==
GGML_TYPE_F32
)
{
ggml_cpy_f16_f32_cuda
(
src0_ddc
,
src1_ddc
,
ne
,
ne00
,
ne01
,
ne02
,
nb00
,
nb01
,
nb02
,
nb03
,
ne10
,
ne11
,
ne12
,
nb10
,
nb11
,
nb12
,
nb13
,
main_stream
);
}
else
{
fprintf
(
stderr
,
"%s: unsupported type combination (%s to %s)
\n
"
,
__func__
,
ggml_type_name
(
src0
->
type
),
ggml_type_name
(
src1
->
type
));
GGML_ASSERT
(
false
);
}
}
void
ggml_cuda_dup
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
ggml_tensor
*
src0
=
dst
->
src
[
0
];
ggml_cuda_cpy
(
ctx
,
src0
,
dst
);
}
void
*
ggml_cuda_cpy_fn
(
const
ggml_tensor
*
src0
,
ggml_tensor
*
src1
)
{
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_F32
)
{
return
(
void
*
)
cpy_f32_f16
<
cpy_1_f32_f32
>
;
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_F16
)
{
return
(
void
*
)
cpy_f32_f16
<
cpy_1_f32_f16
>
;
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_Q8_0
)
{
return
(
void
*
)
cpy_f32_q
<
cpy_blck_f32_q8_0
,
QK8_0
>
;
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_Q4_0
)
{
return
(
void
*
)
cpy_f32_q
<
cpy_blck_f32_q4_0
,
QK4_0
>
;
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_Q4_1
)
{
return
(
void
*
)
cpy_f32_q
<
cpy_blck_f32_q4_1
,
QK4_1
>
;
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_Q5_0
)
{
return
(
void
*
)
cpy_f32_q
<
cpy_blck_f32_q5_0
,
QK5_0
>
;
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_IQ4_NL
)
{
return
(
void
*
)
cpy_f32_q
<
cpy_blck_f32_iq4_nl
,
QK4_NL
>
;
}
else
if
(
src0
->
type
==
GGML_TYPE_F32
&&
src1
->
type
==
GGML_TYPE_Q5_1
)
{
return
(
void
*
)
cpy_f32_q
<
cpy_blck_f32_q5_1
,
QK5_1
>
;
}
else
if
(
src0
->
type
==
GGML_TYPE_F16
&&
src1
->
type
==
GGML_TYPE_F16
)
{
return
(
void
*
)
cpy_f32_f16
<
cpy_1_f32_f16
>
;
}
else
if
(
src0
->
type
==
GGML_TYPE_F16
&&
src1
->
type
==
GGML_TYPE_F32
)
{
return
(
void
*
)
cpy_f32_f16
<
cpy_1_f16_f32
>
;
}
else
{
fprintf
(
stderr
,
"%s: unsupported type combination (%s to %s)
\n
"
,
__func__
,
ggml_type_name
(
src0
->
type
),
ggml_type_name
(
src1
->
type
));
GGML_ASSERT
(
false
);
}
}
llm/llama.cpp/ggml-cuda/cpy.cuh
deleted
100644 → 0
View file @
97b02a89
#include "common.cuh"
#define CUDA_CPY_BLOCK_SIZE 32
void
ggml_cuda_cpy
(
ggml_backend_cuda_context
&
ctx
,
const
ggml_tensor
*
src0
,
ggml_tensor
*
src1
);
void
ggml_cuda_dup
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
);
void
*
ggml_cuda_cpy_fn
(
const
ggml_tensor
*
src0
,
ggml_tensor
*
src1
);
llm/llama.cpp/ggml-cuda/dequantize.cuh
deleted
100644 → 0
View file @
97b02a89
#include "common.cuh"
static
__device__
__forceinline__
void
dequantize_q4_0
(
const
void
*
vx
,
const
int64_t
ib
,
const
int
iqs
,
dfloat2
&
v
){
const
block_q4_0
*
x
=
(
const
block_q4_0
*
)
vx
;
const
dfloat
d
=
x
[
ib
].
d
;
const
int
vui
=
x
[
ib
].
qs
[
iqs
];
v
.
x
=
vui
&
0xF
;
v
.
y
=
vui
>>
4
;
#ifdef GGML_CUDA_F16
v
=
__hsub2
(
v
,
{
8.0
f
,
8.0
f
});
v
=
__hmul2
(
v
,
{
d
,
d
});
#else
v
.
x
=
(
v
.
x
-
8.0
f
)
*
d
;
v
.
y
=
(
v
.
y
-
8.0
f
)
*
d
;
#endif // GGML_CUDA_F16
}
static
__device__
__forceinline__
void
dequantize_q4_1
(
const
void
*
vx
,
const
int64_t
ib
,
const
int
iqs
,
dfloat2
&
v
){
const
block_q4_1
*
x
=
(
const
block_q4_1
*
)
vx
;
const
dfloat
d
=
__low2half
(
x
[
ib
].
dm
);
const
dfloat
m
=
__high2half
(
x
[
ib
].
dm
);
const
int
vui
=
x
[
ib
].
qs
[
iqs
];
v
.
x
=
vui
&
0xF
;
v
.
y
=
vui
>>
4
;
#ifdef GGML_CUDA_F16
v
=
__hmul2
(
v
,
{
d
,
d
});
v
=
__hadd2
(
v
,
{
m
,
m
});
#else
v
.
x
=
(
v
.
x
*
d
)
+
m
;
v
.
y
=
(
v
.
y
*
d
)
+
m
;
#endif // GGML_CUDA_F16
}
static
__device__
__forceinline__
void
dequantize_q5_0
(
const
void
*
vx
,
const
int64_t
ib
,
const
int
iqs
,
dfloat2
&
v
){
const
block_q5_0
*
x
=
(
const
block_q5_0
*
)
vx
;
const
dfloat
d
=
x
[
ib
].
d
;
uint32_t
qh
;
memcpy
(
&
qh
,
x
[
ib
].
qh
,
sizeof
(
qh
));
const
int
xh_0
=
((
qh
>>
(
iqs
+
0
))
<<
4
)
&
0x10
;
const
int
xh_1
=
((
qh
>>
(
iqs
+
12
))
)
&
0x10
;
v
.
x
=
((
x
[
ib
].
qs
[
iqs
]
&
0xf
)
|
xh_0
);
v
.
y
=
((
x
[
ib
].
qs
[
iqs
]
>>
4
)
|
xh_1
);
#ifdef GGML_CUDA_F16
v
=
__hsub2
(
v
,
{
16.0
f
,
16.0
f
});
v
=
__hmul2
(
v
,
{
d
,
d
});
#else
v
.
x
=
(
v
.
x
-
16.0
f
)
*
d
;
v
.
y
=
(
v
.
y
-
16.0
f
)
*
d
;
#endif // GGML_CUDA_F16
}
static
__device__
__forceinline__
void
dequantize_q5_1
(
const
void
*
vx
,
const
int64_t
ib
,
const
int
iqs
,
dfloat2
&
v
){
const
block_q5_1
*
x
=
(
const
block_q5_1
*
)
vx
;
const
dfloat
d
=
__low2half
(
x
[
ib
].
dm
);
const
dfloat
m
=
__high2half
(
x
[
ib
].
dm
);
uint32_t
qh
;
memcpy
(
&
qh
,
x
[
ib
].
qh
,
sizeof
(
qh
));
const
int
xh_0
=
((
qh
>>
(
iqs
+
0
))
<<
4
)
&
0x10
;
const
int
xh_1
=
((
qh
>>
(
iqs
+
12
))
)
&
0x10
;
v
.
x
=
((
x
[
ib
].
qs
[
iqs
]
&
0xf
)
|
xh_0
);
v
.
y
=
((
x
[
ib
].
qs
[
iqs
]
>>
4
)
|
xh_1
);
#ifdef GGML_CUDA_F16
v
=
__hmul2
(
v
,
{
d
,
d
});
v
=
__hadd2
(
v
,
{
m
,
m
});
#else
v
.
x
=
(
v
.
x
*
d
)
+
m
;
v
.
y
=
(
v
.
y
*
d
)
+
m
;
#endif // GGML_CUDA_F16
}
static
__device__
__forceinline__
void
dequantize_q8_0
(
const
void
*
vx
,
const
int64_t
ib
,
const
int
iqs
,
dfloat2
&
v
){
const
block_q8_0
*
x
=
(
const
block_q8_0
*
)
vx
;
const
dfloat
d
=
x
[
ib
].
d
;
v
.
x
=
x
[
ib
].
qs
[
iqs
+
0
];
v
.
y
=
x
[
ib
].
qs
[
iqs
+
1
];
#ifdef GGML_CUDA_F16
v
=
__hmul2
(
v
,
{
d
,
d
});
#else
v
.
x
*=
d
;
v
.
y
*=
d
;
#endif // GGML_CUDA_F16
}
llm/llama.cpp/ggml-cuda/diagmask.cu
deleted
100644 → 0
View file @
97b02a89
#include "diagmask.cuh"
static
__global__
__launch_bounds__
(
1024
)
void
diag_mask_inf_f32
(
const
float
*
x
,
float
*
dst
,
const
int
ncols
,
const
int
rows_per_channel
,
const
int
n_past
)
{
const
int
col
=
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
;
const
int
row
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
if
(
col
>=
ncols
)
{
return
;
}
const
int
i
=
row
*
ncols
+
col
;
//dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
//dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
dst
[
i
]
=
x
[
i
]
-
(
col
>
n_past
+
row
%
rows_per_channel
)
*
FLT_MAX
;
}
static
void
diag_mask_inf_f32_cuda
(
const
float
*
x
,
float
*
dst
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
rows_per_channel
,
const
int
n_past
,
cudaStream_t
stream
)
{
const
dim3
block_dims
(
1
,
CUDA_DIAG_MASK_INF_BLOCK_SIZE
,
1
);
const
int
block_num_x
=
(
ncols_x
+
CUDA_DIAG_MASK_INF_BLOCK_SIZE
-
1
)
/
CUDA_DIAG_MASK_INF_BLOCK_SIZE
;
const
dim3
block_nums
(
nrows_x
,
block_num_x
,
1
);
diag_mask_inf_f32
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
x
,
dst
,
ncols_x
,
rows_per_channel
,
n_past
);
}
void
ggml_cuda_op_diag_mask_inf
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
ggml_tensor
*
src0
=
dst
->
src
[
0
];
const
float
*
src0_d
=
(
const
float
*
)
src0
->
data
;
float
*
dst_d
=
(
float
*
)
dst
->
data
;
cudaStream_t
stream
=
ctx
.
stream
();
GGML_ASSERT
(
src0
->
type
==
GGML_TYPE_F32
);
GGML_ASSERT
(
dst
->
type
==
GGML_TYPE_F32
);
const
int64_t
ne00
=
src0
->
ne
[
0
];
const
int64_t
ne01
=
src0
->
ne
[
1
];
const
int
nrows0
=
ggml_nrows
(
src0
);
const
int
n_past
=
((
int32_t
*
)
dst
->
op_params
)[
0
];
diag_mask_inf_f32_cuda
(
src0_d
,
dst_d
,
ne00
,
nrows0
,
ne01
,
n_past
,
stream
);
}
llm/llama.cpp/ggml-cuda/diagmask.cuh
deleted
100644 → 0
View file @
97b02a89
#include "common.cuh"
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
void
ggml_cuda_op_diag_mask_inf
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
);
llm/llama.cpp/ggml-cuda/dmmv.cu
deleted
100644 → 0
View file @
97b02a89
#include "dmmv.cuh"
#include "dequantize.cuh"
#include "convert.cuh"
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#else
static_assert
(
K_QUANTS_PER_ITERATION
==
1
||
K_QUANTS_PER_ITERATION
==
2
,
"K_QUANTS_PER_ITERATION must be 1 or 2"
);
#endif
static
__global__
__launch_bounds__
(
1024
)
void
dequantize_mul_mat_vec_q2_k
(
const
void
*
__restrict__
vx
,
const
float
*
__restrict__
yy
,
float
*
__restrict__
dst
,
const
int
ncols
,
int
nrows
)
{
static_assert
(
16
%
K_QUANTS_PER_ITERATION
==
0
,
"16 must be divisible by K_QUANTS_PER_ITERATION"
);
const
int
row
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
row
>
nrows
)
return
;
const
int
num_blocks_per_row
=
ncols
/
QK_K
;
const
int
ib0
=
row
*
num_blocks_per_row
;
const
block_q2_K
*
x
=
(
const
block_q2_K
*
)
vx
+
ib0
;
float
tmp
=
0
;
// partial sum for thread in warp
const
int
tid
=
threadIdx
.
x
/
K_QUANTS_PER_ITERATION
;
// 0...31 or 0...15
const
int
ix
=
threadIdx
.
x
%
K_QUANTS_PER_ITERATION
;
// 0 or 0,1
const
int
step
=
16
/
K_QUANTS_PER_ITERATION
;
const
int
im
=
tid
/
step
;
// 0 or 1. 0 computes 0..., 1 computes 128...
const
int
in
=
tid
-
step
*
im
;
// 0...15 or 0...7
const
int
l0
=
K_QUANTS_PER_ITERATION
*
in
;
// 0...15 or 0...14 in steps of 2
const
int
q_offset
=
32
*
im
+
l0
;
const
int
s_offset
=
8
*
im
;
const
int
y_offset
=
128
*
im
+
l0
;
uint32_t
aux
[
4
];
const
uint8_t
*
d
=
(
const
uint8_t
*
)
aux
;
const
uint8_t
*
m
=
(
const
uint8_t
*
)(
aux
+
2
);
for
(
int
i
=
ix
;
i
<
num_blocks_per_row
;
i
+=
K_QUANTS_PER_ITERATION
)
{
const
float
*
y
=
yy
+
i
*
QK_K
+
y_offset
;
const
uint8_t
*
q
=
x
[
i
].
qs
+
q_offset
;
const
float
dall
=
__low2half
(
x
[
i
].
dm
);
const
float
dmin
=
__high2half
(
x
[
i
].
dm
);
const
uint32_t
*
a
=
(
const
uint32_t
*
)(
x
[
i
].
scales
+
s_offset
);
aux
[
0
]
=
a
[
0
]
&
0x0f0f0f0f
;
aux
[
1
]
=
a
[
1
]
&
0x0f0f0f0f
;
aux
[
2
]
=
(
a
[
0
]
>>
4
)
&
0x0f0f0f0f
;
aux
[
3
]
=
(
a
[
1
]
>>
4
)
&
0x0f0f0f0f
;
float
sum1
=
0
,
sum2
=
0
;
for
(
int
l
=
0
;
l
<
K_QUANTS_PER_ITERATION
;
++
l
)
{
sum1
+=
y
[
l
+
0
]
*
d
[
0
]
*
((
q
[
l
+
0
]
>>
0
)
&
3
)
+
y
[
l
+
32
]
*
d
[
2
]
*
((
q
[
l
+
0
]
>>
2
)
&
3
)
+
y
[
l
+
64
]
*
d
[
4
]
*
((
q
[
l
+
0
]
>>
4
)
&
3
)
+
y
[
l
+
96
]
*
d
[
6
]
*
((
q
[
l
+
0
]
>>
6
)
&
3
)
+
y
[
l
+
16
]
*
d
[
1
]
*
((
q
[
l
+
16
]
>>
0
)
&
3
)
+
y
[
l
+
48
]
*
d
[
3
]
*
((
q
[
l
+
16
]
>>
2
)
&
3
)
+
y
[
l
+
80
]
*
d
[
5
]
*
((
q
[
l
+
16
]
>>
4
)
&
3
)
+
y
[
l
+
112
]
*
d
[
7
]
*
((
q
[
l
+
16
]
>>
6
)
&
3
);
sum2
+=
y
[
l
+
0
]
*
m
[
0
]
+
y
[
l
+
32
]
*
m
[
2
]
+
y
[
l
+
64
]
*
m
[
4
]
+
y
[
l
+
96
]
*
m
[
6
]
+
y
[
l
+
16
]
*
m
[
1
]
+
y
[
l
+
48
]
*
m
[
3
]
+
y
[
l
+
80
]
*
m
[
5
]
+
y
[
l
+
112
]
*
m
[
7
];
}
tmp
+=
dall
*
sum1
-
dmin
*
sum2
;
}
// sum up partial sums and write back result
tmp
=
warp_reduce_sum
(
tmp
);
if
(
threadIdx
.
x
==
0
)
{
dst
[
row
]
=
tmp
;
}
}
static
__global__
__launch_bounds__
(
1024
)
void
dequantize_mul_mat_vec_q3_k
(
const
void
*
__restrict__
vx
,
const
float
*
__restrict__
yy
,
float
*
__restrict__
dst
,
const
int
ncols
,
int
nrows
)
{
const
int
row
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
row
>
nrows
)
return
;
const
int
num_blocks_per_row
=
ncols
/
QK_K
;
const
int
ib0
=
row
*
num_blocks_per_row
;
const
block_q3_K
*
x
=
(
const
block_q3_K
*
)
vx
+
ib0
;
float
tmp
=
0
;
// partial sum for thread in warp
const
uint16_t
kmask1
=
0x0303
;
const
uint16_t
kmask2
=
0x0f0f
;
const
int
tid
=
threadIdx
.
x
/
K_QUANTS_PER_ITERATION
;
// 0...31 or 0...16
const
int
ix
=
threadIdx
.
x
%
K_QUANTS_PER_ITERATION
;
// 0 or 0,1
const
int
n
=
K_QUANTS_PER_ITERATION
;
// iterations in the inner loop
const
int
step
=
16
/
K_QUANTS_PER_ITERATION
;
const
int
im
=
tid
/
step
;
// 0 or 1. 0 computes 0..., 1 computes 128...
const
int
in
=
tid
-
step
*
im
;
// 0....15 or 0...7
const
uint8_t
m
=
1
<<
(
4
*
im
);
const
int
l0
=
n
*
in
;
// 0...15 or 0...14 in steps of 2
const
int
q_offset
=
32
*
im
+
l0
;
const
int
y_offset
=
128
*
im
+
l0
;
uint16_t
utmp
[
4
];
const
int8_t
*
s
=
(
const
int8_t
*
)
utmp
;
const
uint16_t
s_shift
=
4
*
im
;
for
(
int
i
=
ix
;
i
<
num_blocks_per_row
;
i
+=
K_QUANTS_PER_ITERATION
)
{
const
float
*
y
=
yy
+
i
*
QK_K
+
y_offset
;
const
uint8_t
*
q
=
x
[
i
].
qs
+
q_offset
;
const
uint8_t
*
h
=
x
[
i
].
hmask
+
l0
;
const
uint16_t
*
a
=
(
const
uint16_t
*
)
x
[
i
].
scales
;
utmp
[
0
]
=
((
a
[
0
]
>>
s_shift
)
&
kmask2
)
|
(((
a
[
4
]
>>
(
s_shift
+
0
))
&
kmask1
)
<<
4
);
utmp
[
1
]
=
((
a
[
1
]
>>
s_shift
)
&
kmask2
)
|
(((
a
[
5
]
>>
(
s_shift
+
0
))
&
kmask1
)
<<
4
);
utmp
[
2
]
=
((
a
[
2
]
>>
s_shift
)
&
kmask2
)
|
(((
a
[
4
]
>>
(
s_shift
+
2
))
&
kmask1
)
<<
4
);
utmp
[
3
]
=
((
a
[
3
]
>>
s_shift
)
&
kmask2
)
|
(((
a
[
5
]
>>
(
s_shift
+
2
))
&
kmask1
)
<<
4
);
const
float
d
=
x
[
i
].
d
;
float
sum
=
0
;
for
(
int
l
=
0
;
l
<
n
;
++
l
)
{
sum
+=
y
[
l
+
0
]
*
(
s
[
0
]
-
32
)
*
(((
q
[
l
]
>>
0
)
&
3
)
-
(
h
[
l
]
&
(
m
<<
0
)
?
0
:
4
))
+
y
[
l
+
32
]
*
(
s
[
2
]
-
32
)
*
(((
q
[
l
]
>>
2
)
&
3
)
-
(
h
[
l
]
&
(
m
<<
1
)
?
0
:
4
))
+
y
[
l
+
64
]
*
(
s
[
4
]
-
32
)
*
(((
q
[
l
]
>>
4
)
&
3
)
-
(
h
[
l
]
&
(
m
<<
2
)
?
0
:
4
))
+
y
[
l
+
96
]
*
(
s
[
6
]
-
32
)
*
(((
q
[
l
]
>>
6
)
&
3
)
-
(
h
[
l
]
&
(
m
<<
3
)
?
0
:
4
));
sum
+=
y
[
l
+
16
]
*
(
s
[
1
]
-
32
)
*
(((
q
[
l
+
16
]
>>
0
)
&
3
)
-
(
h
[
l
+
16
]
&
(
m
<<
0
)
?
0
:
4
))
+
y
[
l
+
48
]
*
(
s
[
3
]
-
32
)
*
(((
q
[
l
+
16
]
>>
2
)
&
3
)
-
(
h
[
l
+
16
]
&
(
m
<<
1
)
?
0
:
4
))
+
y
[
l
+
80
]
*
(
s
[
5
]
-
32
)
*
(((
q
[
l
+
16
]
>>
4
)
&
3
)
-
(
h
[
l
+
16
]
&
(
m
<<
2
)
?
0
:
4
))
+
y
[
l
+
112
]
*
(
s
[
7
]
-
32
)
*
(((
q
[
l
+
16
]
>>
6
)
&
3
)
-
(
h
[
l
+
16
]
&
(
m
<<
3
)
?
0
:
4
));
}
tmp
+=
d
*
sum
;
}
// sum up partial sums and write back result
tmp
=
warp_reduce_sum
(
tmp
);
if
(
threadIdx
.
x
==
0
)
{
dst
[
row
]
=
tmp
;
}
}
static
__global__
__launch_bounds__
(
1024
)
void
dequantize_mul_mat_vec_q4_k
(
const
void
*
__restrict__
vx
,
const
float
*
__restrict__
yy
,
float
*
__restrict__
dst
,
const
int
ncols
,
int
nrows
)
{
const
int
row
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
row
>
nrows
)
return
;
const
int
num_blocks_per_row
=
ncols
/
QK_K
;
const
int
ib0
=
row
*
num_blocks_per_row
;
const
block_q4_K
*
x
=
(
const
block_q4_K
*
)
vx
+
ib0
;
const
uint16_t
kmask1
=
0x3f3f
;
const
uint16_t
kmask2
=
0x0f0f
;
const
uint16_t
kmask3
=
0xc0c0
;
const
int
tid
=
threadIdx
.
x
/
K_QUANTS_PER_ITERATION
;
// 0...31 or 0...16
const
int
ix
=
threadIdx
.
x
%
K_QUANTS_PER_ITERATION
;
// 0 or 0,1
const
int
step
=
8
/
K_QUANTS_PER_ITERATION
;
// 8 or 4
const
int
il
=
tid
/
step
;
// 0...3
const
int
ir
=
tid
-
step
*
il
;
// 0...7 or 0...3
const
int
n
=
2
*
K_QUANTS_PER_ITERATION
;
// 2 or 4
const
int
im
=
il
/
2
;
// 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const
int
in
=
il
%
2
;
const
int
l0
=
n
*
(
2
*
ir
+
in
);
const
int
q_offset
=
32
*
im
+
l0
;
const
int
y_offset
=
64
*
im
+
l0
;
uint16_t
aux
[
4
];
const
uint8_t
*
sc
=
(
const
uint8_t
*
)
aux
;
#if K_QUANTS_PER_ITERATION == 2
uint32_t
q32
[
4
];
const
uint8_t
*
q4
=
(
const
uint8_t
*
)
q32
;
#else
uint16_t
q16
[
4
];
const
uint8_t
*
q4
=
(
const
uint8_t
*
)
q16
;
#endif
float
tmp
=
0
;
// partial sum for thread in warp
for
(
int
i
=
ix
;
i
<
num_blocks_per_row
;
i
+=
K_QUANTS_PER_ITERATION
)
{
const
float
*
y1
=
yy
+
i
*
QK_K
+
y_offset
;
const
float
*
y2
=
y1
+
128
;
const
float
dall
=
__low2half
(
x
[
i
].
dm
);
const
float
dmin
=
__high2half
(
x
[
i
].
dm
);
const
uint16_t
*
a
=
(
const
uint16_t
*
)
x
[
i
].
scales
;
aux
[
0
]
=
a
[
im
+
0
]
&
kmask1
;
aux
[
1
]
=
a
[
im
+
2
]
&
kmask1
;
aux
[
2
]
=
((
a
[
im
+
4
]
>>
0
)
&
kmask2
)
|
((
a
[
im
+
0
]
&
kmask3
)
>>
2
);
aux
[
3
]
=
((
a
[
im
+
4
]
>>
4
)
&
kmask2
)
|
((
a
[
im
+
2
]
&
kmask3
)
>>
2
);
#if K_QUANTS_PER_ITERATION == 2
const
uint32_t
*
q1
=
(
const
uint32_t
*
)(
x
[
i
].
qs
+
q_offset
);
const
uint32_t
*
q2
=
q1
+
16
;
q32
[
0
]
=
q1
[
0
]
&
0x0f0f0f0f
;
q32
[
1
]
=
q1
[
0
]
&
0xf0f0f0f0
;
q32
[
2
]
=
q2
[
0
]
&
0x0f0f0f0f
;
q32
[
3
]
=
q2
[
0
]
&
0xf0f0f0f0
;
float4
s
=
{
0.
f
,
0.
f
,
0.
f
,
0.
f
};
float
smin
=
0
;
for
(
int
l
=
0
;
l
<
4
;
++
l
)
{
s
.
x
+=
y1
[
l
]
*
q4
[
l
+
0
];
s
.
y
+=
y1
[
l
+
32
]
*
q4
[
l
+
4
];
s
.
z
+=
y2
[
l
]
*
q4
[
l
+
8
];
s
.
w
+=
y2
[
l
+
32
]
*
q4
[
l
+
12
];
smin
+=
y1
[
l
]
*
sc
[
2
]
+
y1
[
l
+
32
]
*
sc
[
3
]
+
y2
[
l
]
*
sc
[
6
]
+
y2
[
l
+
32
]
*
sc
[
7
];
}
tmp
+=
dall
*
(
s
.
x
*
sc
[
0
]
+
s
.
y
*
sc
[
1
]
*
1.
f
/
16.
f
+
s
.
z
*
sc
[
4
]
+
s
.
w
*
sc
[
5
]
*
1.
f
/
16.
f
)
-
dmin
*
smin
;
#else
const
uint16_t
*
q1
=
(
const
uint16_t
*
)(
x
[
i
].
qs
+
q_offset
);
const
uint16_t
*
q2
=
q1
+
32
;
q16
[
0
]
=
q1
[
0
]
&
0x0f0f
;
q16
[
1
]
=
q1
[
0
]
&
0xf0f0
;
q16
[
2
]
=
q2
[
0
]
&
0x0f0f
;
q16
[
3
]
=
q2
[
0
]
&
0xf0f0
;
float4
s
=
{
0.
f
,
0.
f
,
0.
f
,
0.
f
};
float
smin
=
0
;
for
(
int
l
=
0
;
l
<
2
;
++
l
)
{
s
.
x
+=
y1
[
l
]
*
q4
[
l
+
0
];
s
.
y
+=
y1
[
l
+
32
]
*
q4
[
l
+
2
];
s
.
z
+=
y2
[
l
]
*
q4
[
l
+
4
];
s
.
w
+=
y2
[
l
+
32
]
*
q4
[
l
+
6
];
smin
+=
y1
[
l
]
*
sc
[
2
]
+
y1
[
l
+
32
]
*
sc
[
3
]
+
y2
[
l
]
*
sc
[
6
]
+
y2
[
l
+
32
]
*
sc
[
7
];
}
tmp
+=
dall
*
(
s
.
x
*
sc
[
0
]
+
s
.
y
*
sc
[
1
]
*
1.
f
/
16.
f
+
s
.
z
*
sc
[
4
]
+
s
.
w
*
sc
[
5
]
*
1.
f
/
16.
f
)
-
dmin
*
smin
;
#endif
}
// sum up partial sums and write back result
tmp
=
warp_reduce_sum
(
tmp
);
if
(
tid
==
0
)
{
dst
[
row
]
=
tmp
;
}
}
static
__global__
__launch_bounds__
(
1024
)
void
dequantize_mul_mat_vec_q5_k
(
const
void
*
__restrict__
vx
,
const
float
*
__restrict__
yy
,
float
*
__restrict__
dst
,
const
int
ncols
)
{
const
int
row
=
blockIdx
.
x
;
const
int
num_blocks_per_row
=
ncols
/
QK_K
;
const
int
ib0
=
row
*
num_blocks_per_row
;
const
block_q5_K
*
x
=
(
const
block_q5_K
*
)
vx
+
ib0
;
float
tmp
=
0
;
// partial sum for thread in warp
const
uint16_t
kmask1
=
0x3f3f
;
const
uint16_t
kmask2
=
0x0f0f
;
const
uint16_t
kmask3
=
0xc0c0
;
const
int
tid
=
threadIdx
.
x
/
2
;
// 0...15
const
int
ix
=
threadIdx
.
x
%
2
;
const
int
il
=
tid
/
4
;
// 0...3
const
int
ir
=
tid
-
4
*
il
;
// 0...3
const
int
n
=
2
;
const
int
im
=
il
/
2
;
// 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const
int
in
=
il
%
2
;
const
int
l0
=
n
*
(
2
*
ir
+
in
);
const
int
q_offset
=
32
*
im
+
l0
;
const
int
y_offset
=
64
*
im
+
l0
;
const
uint8_t
hm1
=
1
<<
(
2
*
im
);
const
uint8_t
hm2
=
hm1
<<
4
;
uint16_t
aux
[
4
];
const
uint8_t
*
sc
=
(
const
uint8_t
*
)
aux
;
uint16_t
q16
[
8
];
const
uint8_t
*
q4
=
(
const
uint8_t
*
)
q16
;
for
(
int
i
=
ix
;
i
<
num_blocks_per_row
;
i
+=
2
)
{
const
uint8_t
*
ql1
=
x
[
i
].
qs
+
q_offset
;
const
uint8_t
*
qh
=
x
[
i
].
qh
+
l0
;
const
float
*
y1
=
yy
+
i
*
QK_K
+
y_offset
;
const
float
*
y2
=
y1
+
128
;
const
float
dall
=
__low2half
(
x
[
i
].
dm
);
const
float
dmin
=
__high2half
(
x
[
i
].
dm
);
const
uint16_t
*
a
=
(
const
uint16_t
*
)
x
[
i
].
scales
;
aux
[
0
]
=
a
[
im
+
0
]
&
kmask1
;
aux
[
1
]
=
a
[
im
+
2
]
&
kmask1
;
aux
[
2
]
=
((
a
[
im
+
4
]
>>
0
)
&
kmask2
)
|
((
a
[
im
+
0
]
&
kmask3
)
>>
2
);
aux
[
3
]
=
((
a
[
im
+
4
]
>>
4
)
&
kmask2
)
|
((
a
[
im
+
2
]
&
kmask3
)
>>
2
);
float4
sum
=
{
0.
f
,
0.
f
,
0.
f
,
0.
f
};
float
smin
=
0
;
const
uint16_t
*
q1
=
(
const
uint16_t
*
)
ql1
;
const
uint16_t
*
q2
=
q1
+
32
;
q16
[
0
]
=
q1
[
0
]
&
0x0f0f
;
q16
[
1
]
=
q1
[
8
]
&
0x0f0f
;
q16
[
2
]
=
(
q1
[
0
]
>>
4
)
&
0x0f0f
;
q16
[
3
]
=
(
q1
[
8
]
>>
4
)
&
0x0f0f
;
q16
[
4
]
=
q2
[
0
]
&
0x0f0f
;
q16
[
5
]
=
q2
[
8
]
&
0x0f0f
;
q16
[
6
]
=
(
q2
[
0
]
>>
4
)
&
0x0f0f
;
q16
[
7
]
=
(
q2
[
8
]
>>
4
)
&
0x0f0f
;
for
(
int
l
=
0
;
l
<
n
;
++
l
)
{
sum
.
x
+=
y1
[
l
+
0
]
*
(
q4
[
l
+
0
]
+
(
qh
[
l
+
0
]
&
(
hm1
<<
0
)
?
16
:
0
))
+
y1
[
l
+
16
]
*
(
q4
[
l
+
2
]
+
(
qh
[
l
+
16
]
&
(
hm1
<<
0
)
?
16
:
0
));
sum
.
y
+=
y1
[
l
+
32
]
*
(
q4
[
l
+
4
]
+
(
qh
[
l
+
0
]
&
(
hm1
<<
1
)
?
16
:
0
))
+
y1
[
l
+
48
]
*
(
q4
[
l
+
6
]
+
(
qh
[
l
+
16
]
&
(
hm1
<<
1
)
?
16
:
0
));
sum
.
z
+=
y2
[
l
+
0
]
*
(
q4
[
l
+
8
]
+
(
qh
[
l
+
0
]
&
(
hm2
<<
0
)
?
16
:
0
))
+
y2
[
l
+
16
]
*
(
q4
[
l
+
10
]
+
(
qh
[
l
+
16
]
&
(
hm2
<<
0
)
?
16
:
0
));
sum
.
w
+=
y2
[
l
+
32
]
*
(
q4
[
l
+
12
]
+
(
qh
[
l
+
0
]
&
(
hm2
<<
1
)
?
16
:
0
))
+
y2
[
l
+
48
]
*
(
q4
[
l
+
14
]
+
(
qh
[
l
+
16
]
&
(
hm2
<<
1
)
?
16
:
0
));
smin
+=
(
y1
[
l
]
+
y1
[
l
+
16
])
*
sc
[
2
]
+
(
y1
[
l
+
32
]
+
y1
[
l
+
48
])
*
sc
[
3
]
+
(
y2
[
l
]
+
y2
[
l
+
16
])
*
sc
[
6
]
+
(
y2
[
l
+
32
]
+
y2
[
l
+
48
])
*
sc
[
7
];
}
tmp
+=
dall
*
(
sum
.
x
*
sc
[
0
]
+
sum
.
y
*
sc
[
1
]
+
sum
.
z
*
sc
[
4
]
+
sum
.
w
*
sc
[
5
])
-
dmin
*
smin
;
}
// sum up partial sums and write back result
tmp
=
warp_reduce_sum
(
tmp
);
if
(
threadIdx
.
x
==
0
)
{
dst
[
row
]
=
tmp
;
}
}
static
__global__
__launch_bounds__
(
1024
)
void
dequantize_mul_mat_vec_q6_k
(
const
void
*
__restrict__
vx
,
const
float
*
__restrict__
yy
,
float
*
__restrict__
dst
,
const
int
ncols
,
int
nrows
)
{
static_assert
(
16
%
K_QUANTS_PER_ITERATION
==
0
,
"16 must be divisible by K_QUANTS_PER_ITERATION"
);
const
int
row
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
row
>
nrows
)
return
;
const
int
num_blocks_per_row
=
ncols
/
QK_K
;
const
int
ib0
=
row
*
num_blocks_per_row
;
const
block_q6_K
*
x
=
(
const
block_q6_K
*
)
vx
+
ib0
;
const
int
tid
=
threadIdx
.
x
/
K_QUANTS_PER_ITERATION
;
// 0...31 or 0...16
const
int
ix
=
threadIdx
.
x
%
K_QUANTS_PER_ITERATION
;
// 0 or 0, 1
const
int
step
=
16
/
K_QUANTS_PER_ITERATION
;
// 16 or 8
const
int
im
=
tid
/
step
;
// 0 or 1. 0 computes 0..., 1 computes 128...
const
int
in
=
tid
-
step
*
im
;
// 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const
int
l0
=
K_QUANTS_PER_ITERATION
*
in
;
// 0...15
const
int
is
=
0
;
#else
const
int
l0
=
4
*
in
;
// 0, 4, 8, ..., 28
const
int
is
=
in
/
4
;
#endif
const
int
ql_offset
=
64
*
im
+
l0
;
const
int
qh_offset
=
32
*
im
+
l0
;
const
int
s_offset
=
8
*
im
+
is
;
const
int
y_offset
=
128
*
im
+
l0
;
float
tmp
=
0
;
// partial sum for thread in warp
for
(
int
i
=
ix
;
i
<
num_blocks_per_row
;
i
+=
K_QUANTS_PER_ITERATION
)
{
const
float
*
y
=
yy
+
i
*
QK_K
+
y_offset
;
const
uint8_t
*
ql
=
x
[
i
].
ql
+
ql_offset
;
const
uint8_t
*
qh
=
x
[
i
].
qh
+
qh_offset
;
const
int8_t
*
s
=
x
[
i
].
scales
+
s_offset
;
const
float
d
=
x
[
i
].
d
;
#if K_QUANTS_PER_ITERATION == 1
float
sum
=
y
[
0
]
*
s
[
0
]
*
d
*
((
int8_t
)((
ql
[
0
]
&
0xF
)
|
((
qh
[
0
]
&
0x03
)
<<
4
))
-
32
)
+
y
[
16
]
*
s
[
1
]
*
d
*
((
int8_t
)((
ql
[
16
]
&
0xF
)
|
((
qh
[
16
]
&
0x03
)
<<
4
))
-
32
)
+
y
[
32
]
*
s
[
2
]
*
d
*
((
int8_t
)((
ql
[
32
]
&
0xF
)
|
((
qh
[
0
]
&
0x0c
)
<<
2
))
-
32
)
+
y
[
48
]
*
s
[
3
]
*
d
*
((
int8_t
)((
ql
[
48
]
&
0xF
)
|
((
qh
[
16
]
&
0x0c
)
<<
2
))
-
32
)
+
y
[
64
]
*
s
[
4
]
*
d
*
((
int8_t
)((
ql
[
0
]
>>
4
)
|
((
qh
[
0
]
&
0x30
)
>>
0
))
-
32
)
+
y
[
80
]
*
s
[
5
]
*
d
*
((
int8_t
)((
ql
[
16
]
>>
4
)
|
((
qh
[
16
]
&
0x30
)
>>
0
))
-
32
)
+
y
[
96
]
*
s
[
6
]
*
d
*
((
int8_t
)((
ql
[
32
]
>>
4
)
|
((
qh
[
0
]
&
0xc0
)
>>
2
))
-
32
)
+
y
[
112
]
*
s
[
7
]
*
d
*
((
int8_t
)((
ql
[
48
]
>>
4
)
|
((
qh
[
16
]
&
0xc0
)
>>
2
))
-
32
);
tmp
+=
sum
;
#else
float
sum
=
0
;
for
(
int
l
=
0
;
l
<
4
;
++
l
)
{
sum
+=
y
[
l
+
0
]
*
s
[
0
]
*
d
*
((
int8_t
)((
ql
[
l
+
0
]
&
0xF
)
|
(((
qh
[
l
]
>>
0
)
&
3
)
<<
4
))
-
32
)
+
y
[
l
+
32
]
*
s
[
2
]
*
d
*
((
int8_t
)((
ql
[
l
+
32
]
&
0xF
)
|
(((
qh
[
l
]
>>
2
)
&
3
)
<<
4
))
-
32
)
+
y
[
l
+
64
]
*
s
[
4
]
*
d
*
((
int8_t
)((
ql
[
l
+
0
]
>>
4
)
|
(((
qh
[
l
]
>>
4
)
&
3
)
<<
4
))
-
32
)
+
y
[
l
+
96
]
*
s
[
6
]
*
d
*
((
int8_t
)((
ql
[
l
+
32
]
>>
4
)
|
(((
qh
[
l
]
>>
6
)
&
3
)
<<
4
))
-
32
);
}
tmp
+=
sum
;
#endif
}
// sum up partial sums and write back result
tmp
=
warp_reduce_sum
(
tmp
);
if
(
tid
==
0
)
{
dst
[
row
]
=
tmp
;
}
}
static
__device__
void
convert_f16
(
const
void
*
vx
,
const
int64_t
ib
,
const
int
iqs
,
dfloat2
&
v
){
const
half
*
x
=
(
const
half
*
)
vx
;
// automatic half -> float type cast if dfloat == float
v
.
x
=
x
[
ib
+
iqs
+
0
];
v
.
y
=
x
[
ib
+
iqs
+
1
];
}
template
<
int
qk
,
int
qr
,
dequantize_kernel_t
dequantize_kernel
>
static
__global__
__launch_bounds__
(
1024
)
void
dequantize_mul_mat_vec
(
const
void
*
__restrict__
vx
,
const
dfloat
*
__restrict__
y
,
float
*
__restrict__
dst
,
const
int
ncols
,
const
int
nrows
)
{
// qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
const
int64_t
row
=
(
int64_t
)
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
row
>=
nrows
)
{
return
;
}
const
int
tid
=
threadIdx
.
x
;
const
int
iter_stride
=
2
*
GGML_CUDA_DMMV_X
;
const
int
vals_per_iter
=
iter_stride
/
WARP_SIZE
;
// num quantized vals per thread and i iter
const
int
y_offset
=
qr
==
1
?
1
:
qk
/
2
;
// partial sum for each thread
#ifdef GGML_CUDA_F16
half2
tmp
=
{
0.0
f
,
0.0
f
};
// two sums for f16 to take advantage of half2 intrinsics
#else
float
tmp
=
0.0
f
;
#endif // GGML_CUDA_F16
for
(
int
i
=
0
;
i
<
ncols
;
i
+=
iter_stride
)
{
const
int
col
=
i
+
vals_per_iter
*
tid
;
const
int64_t
ib
=
((
int64_t
)
row
*
ncols
+
col
)
/
qk
;
// x block index
const
int
iqs
=
(
col
%
qk
)
/
qr
;
// x quant index
const
int
iybs
=
col
-
col
%
qk
;
// y block start index
// processing >2 values per i iter is faster for fast GPUs
#pragma unroll
for
(
int
j
=
0
;
j
<
vals_per_iter
;
j
+=
2
)
{
// process 2 vals per j iter
// dequantize
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
dfloat2
v
;
dequantize_kernel
(
vx
,
ib
,
iqs
+
j
/
qr
,
v
);
// matrix multiplication
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
#ifdef GGML_CUDA_F16
tmp
+=
__hmul2
(
v
,
{
y
[
iybs
+
iqs
+
j
/
qr
+
0
],
y
[
iybs
+
iqs
+
j
/
qr
+
y_offset
]
});
#else
tmp
+=
v
.
x
*
y
[
iybs
+
iqs
+
j
/
qr
+
0
];
tmp
+=
v
.
y
*
y
[
iybs
+
iqs
+
j
/
qr
+
y_offset
];
#endif // GGML_CUDA_F16
}
}
// sum up partial sums and write back result
tmp
=
warp_reduce_sum
(
tmp
);
if
(
tid
==
0
)
{
#ifdef GGML_CUDA_F16
dst
[
row
]
=
tmp
.
x
+
tmp
.
y
;
#else
dst
[
row
]
=
tmp
;
#endif // GGML_CUDA_F16
}
}
static
void
dequantize_mul_mat_vec_q4_0_cuda
(
const
void
*
vx
,
const
dfloat
*
y
,
float
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ncols
%
GGML_CUDA_DMMV_X
==
0
);
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
dequantize_mul_mat_vec
<
QK4_0
,
QR4_0
,
dequantize_q4_0
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
y
,
dst
,
ncols
,
nrows
);
}
static
void
dequantize_mul_mat_vec_q4_1_cuda
(
const
void
*
vx
,
const
dfloat
*
y
,
float
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ncols
%
GGML_CUDA_DMMV_X
==
0
);
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
dequantize_mul_mat_vec
<
QK4_1
,
QR4_1
,
dequantize_q4_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
y
,
dst
,
ncols
,
nrows
);
}
static
void
dequantize_mul_mat_vec_q5_0_cuda
(
const
void
*
vx
,
const
dfloat
*
y
,
float
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ncols
%
GGML_CUDA_DMMV_X
==
0
);
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
dequantize_mul_mat_vec
<
QK5_0
,
QR5_0
,
dequantize_q5_0
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
y
,
dst
,
ncols
,
nrows
);
}
static
void
dequantize_mul_mat_vec_q5_1_cuda
(
const
void
*
vx
,
const
dfloat
*
y
,
float
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ncols
%
GGML_CUDA_DMMV_X
==
0
);
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
dequantize_mul_mat_vec
<
QK5_1
,
QR5_1
,
dequantize_q5_1
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
y
,
dst
,
ncols
,
nrows
);
}
static
void
dequantize_mul_mat_vec_q8_0_cuda
(
const
void
*
vx
,
const
dfloat
*
y
,
float
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ncols
%
GGML_CUDA_DMMV_X
==
0
);
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
dequantize_mul_mat_vec
<
QK8_0
,
QR8_0
,
dequantize_q8_0
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
y
,
dst
,
ncols
,
nrows
);
}
static
void
dequantize_mul_mat_vec_q2_K_cuda
(
const
void
*
vx
,
const
float
*
y
,
float
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ncols
%
QK_K
==
0
);
const
int
ny
=
2
;
// very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
const
int
block_num_y
=
(
nrows
+
ny
-
1
)
/
ny
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
32
,
ny
,
1
);
dequantize_mul_mat_vec_q2_k
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
y
,
dst
,
ncols
,
nrows
);
}
static
void
dequantize_mul_mat_vec_q3_K_cuda
(
const
void
*
vx
,
const
float
*
y
,
float
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ncols
%
QK_K
==
0
);
const
int
ny
=
2
/
K_QUANTS_PER_ITERATION
;
const
int
block_num_y
=
(
nrows
+
ny
-
1
)
/
ny
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
32
,
ny
,
1
);
dequantize_mul_mat_vec_q3_k
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
y
,
dst
,
ncols
,
nrows
);
}
static
void
dequantize_mul_mat_vec_q4_K_cuda
(
const
void
*
vx
,
const
float
*
y
,
float
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ncols
%
QK_K
==
0
);
const
int
ny
=
2
/
K_QUANTS_PER_ITERATION
;
const
int
block_num_y
=
(
nrows
+
ny
-
1
)
/
ny
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
32
,
ny
,
1
);
dequantize_mul_mat_vec_q4_k
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
y
,
dst
,
ncols
,
nrows
);
}
static
void
dequantize_mul_mat_vec_q5_K_cuda
(
const
void
*
vx
,
const
float
*
y
,
float
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ncols
%
QK_K
==
0
);
const
dim3
block_dims
(
32
,
1
,
1
);
dequantize_mul_mat_vec_q5_k
<<<
nrows
,
block_dims
,
0
,
stream
>>>
(
vx
,
y
,
dst
,
ncols
);
}
static
void
dequantize_mul_mat_vec_q6_K_cuda
(
const
void
*
vx
,
const
float
*
y
,
float
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ncols
%
QK_K
==
0
);
const
int
ny
=
2
/
K_QUANTS_PER_ITERATION
;
const
int
block_num_y
=
(
nrows
+
ny
-
1
)
/
ny
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
32
,
ny
,
1
);
dequantize_mul_mat_vec_q6_k
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
y
,
dst
,
ncols
,
nrows
);
}
static
void
convert_mul_mat_vec_f16_cuda
(
const
void
*
vx
,
const
dfloat
*
y
,
float
*
dst
,
const
int
ncols
,
const
int
nrows
,
cudaStream_t
stream
)
{
GGML_ASSERT
(
ncols
%
GGML_CUDA_DMMV_X
==
0
);
const
int
block_num_y
=
(
nrows
+
GGML_CUDA_MMV_Y
-
1
)
/
GGML_CUDA_MMV_Y
;
const
dim3
block_nums
(
block_num_y
,
1
,
1
);
const
dim3
block_dims
(
WARP_SIZE
,
GGML_CUDA_MMV_Y
,
1
);
dequantize_mul_mat_vec
<
1
,
1
,
convert_f16
>
<<<
block_nums
,
block_dims
,
0
,
stream
>>>
(
vx
,
y
,
dst
,
ncols
,
nrows
);
}
void
ggml_cuda_op_dequantize_mul_mat_vec
(
ggml_backend_cuda_context
&
ctx
,
const
ggml_tensor
*
src0
,
const
ggml_tensor
*
src1
,
ggml_tensor
*
dst
,
const
char
*
src0_dd_i
,
const
float
*
src1_ddf_i
,
const
char
*
src1_ddq_i
,
float
*
dst_dd_i
,
const
int64_t
row_low
,
const
int64_t
row_high
,
const
int64_t
src1_ncols
,
const
int64_t
src1_padded_row_size
,
cudaStream_t
stream
)
{
GGML_UNUSED
(
ctx
);
const
int64_t
ne00
=
src0
->
ne
[
0
];
const
int64_t
row_diff
=
row_high
-
row_low
;
GGML_ASSERT
(
src1
->
type
==
GGML_TYPE_F32
);
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_CUDA_F16
ggml_cuda_pool_alloc
<
half
>
src1_dfloat_a
(
ctx
.
pool
());
half
*
src1_dfloat
=
nullptr
;
// dfloat == half
bool
src1_convert_f16
=
src0
->
type
==
GGML_TYPE_Q4_0
||
src0
->
type
==
GGML_TYPE_Q4_1
||
src0
->
type
==
GGML_TYPE_Q5_0
||
src0
->
type
==
GGML_TYPE_Q5_1
||
src0
->
type
==
GGML_TYPE_Q8_0
||
src0
->
type
==
GGML_TYPE_F16
;
if
(
src1_convert_f16
)
{
src1_dfloat
=
src1_dfloat_a
.
alloc
(
ne00
);
const
to_fp16_cuda_t
to_fp16_cuda
=
ggml_get_to_fp16_cuda
(
src1
->
type
);
GGML_ASSERT
(
to_fp16_cuda
!=
nullptr
);
to_fp16_cuda
(
src1_ddf_i
,
src1_dfloat
,
ne00
,
stream
);
}
#else
const
dfloat
*
src1_dfloat
=
(
const
dfloat
*
)
src1_ddf_i
;
// dfloat == float, no conversion
#endif // GGML_CUDA_F16
switch
(
src0
->
type
)
{
case
GGML_TYPE_Q4_0
:
dequantize_mul_mat_vec_q4_0_cuda
(
src0_dd_i
,
src1_dfloat
,
dst_dd_i
,
ne00
,
row_diff
,
stream
);
break
;
case
GGML_TYPE_Q4_1
:
dequantize_mul_mat_vec_q4_1_cuda
(
src0_dd_i
,
src1_dfloat
,
dst_dd_i
,
ne00
,
row_diff
,
stream
);
break
;
case
GGML_TYPE_Q5_0
:
dequantize_mul_mat_vec_q5_0_cuda
(
src0_dd_i
,
src1_dfloat
,
dst_dd_i
,
ne00
,
row_diff
,
stream
);
break
;
case
GGML_TYPE_Q5_1
:
dequantize_mul_mat_vec_q5_1_cuda
(
src0_dd_i
,
src1_dfloat
,
dst_dd_i
,
ne00
,
row_diff
,
stream
);
break
;
case
GGML_TYPE_Q8_0
:
dequantize_mul_mat_vec_q8_0_cuda
(
src0_dd_i
,
src1_dfloat
,
dst_dd_i
,
ne00
,
row_diff
,
stream
);
break
;
case
GGML_TYPE_Q2_K
:
dequantize_mul_mat_vec_q2_K_cuda
(
src0_dd_i
,
src1_ddf_i
,
dst_dd_i
,
ne00
,
row_diff
,
stream
);
break
;
case
GGML_TYPE_Q3_K
:
dequantize_mul_mat_vec_q3_K_cuda
(
src0_dd_i
,
src1_ddf_i
,
dst_dd_i
,
ne00
,
row_diff
,
stream
);
break
;
case
GGML_TYPE_Q4_K
:
dequantize_mul_mat_vec_q4_K_cuda
(
src0_dd_i
,
src1_ddf_i
,
dst_dd_i
,
ne00
,
row_diff
,
stream
);
break
;
case
GGML_TYPE_Q5_K
:
dequantize_mul_mat_vec_q5_K_cuda
(
src0_dd_i
,
src1_ddf_i
,
dst_dd_i
,
ne00
,
row_diff
,
stream
);
break
;
case
GGML_TYPE_Q6_K
:
dequantize_mul_mat_vec_q6_K_cuda
(
src0_dd_i
,
src1_ddf_i
,
dst_dd_i
,
ne00
,
row_diff
,
stream
);
break
;
case
GGML_TYPE_F16
:
convert_mul_mat_vec_f16_cuda
(
src0_dd_i
,
src1_dfloat
,
dst_dd_i
,
ne00
,
row_diff
,
stream
);
break
;
default:
GGML_ASSERT
(
false
);
break
;
}
GGML_UNUSED
(
src1
);
GGML_UNUSED
(
dst
);
GGML_UNUSED
(
src1_ddq_i
);
GGML_UNUSED
(
src1_ncols
);
GGML_UNUSED
(
src1_padded_row_size
);
}
llm/llama.cpp/ggml-cuda/dmmv.cuh
deleted
100644 → 0
View file @
97b02a89
#include "common.cuh"
// dmmv = dequantize_mul_mat_vec
// TODO: remove this?
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32
#endif
#ifndef GGML_CUDA_MMV_Y
#define GGML_CUDA_MMV_Y 1
#endif
void
ggml_cuda_op_dequantize_mul_mat_vec
(
ggml_backend_cuda_context
&
ctx
,
const
ggml_tensor
*
src0
,
const
ggml_tensor
*
src1
,
ggml_tensor
*
dst
,
const
char
*
src0_dd_i
,
const
float
*
src1_ddf_i
,
const
char
*
src1_ddq_i
,
float
*
dst_dd_i
,
const
int64_t
row_low
,
const
int64_t
row_high
,
const
int64_t
src1_ncols
,
const
int64_t
src1_padded_row_size
,
cudaStream_t
stream
);
llm/llama.cpp/ggml-cuda/fattn-common.cuh
deleted
100644 → 0
View file @
97b02a89
#include "common.cuh"
#include <cstdint>
#define FATTN_KQ_STRIDE 256
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
typedef
void
(
*
fattn_kernel_t
)(
const
char
*
__restrict__
Q
,
const
char
*
__restrict__
K
,
const
char
*
__restrict__
V
,
const
char
*
__restrict__
mask
,
float
*
__restrict__
dst
,
float2
*
__restrict__
dst_meta
,
const
float
scale
,
const
float
max_bias
,
const
float
m0
,
const
float
m1
,
const
uint32_t
n_head_log2
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
ne03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
ne13
,
const
int
ne31
,
const
int
nb31
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
const
int
ne0
,
const
int
ne1
,
const
int
ne2
,
const
int
ne3
);
template
<
int
D
,
int
parallel_blocks
>
// D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__
(
D
,
1
)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static
__global__
__launch_bounds__
(
1024
)
void
flash_attn_combine_results
(
const
float
*
__restrict__
VKQ_parts
,
const
float2
*
__restrict__
VKQ_meta
,
float
*
__restrict__
dst
)
{
VKQ_parts
+=
parallel_blocks
*
D
*
gridDim
.
y
*
blockIdx
.
x
;
VKQ_meta
+=
parallel_blocks
*
gridDim
.
y
*
blockIdx
.
x
;
dst
+=
D
*
gridDim
.
y
*
blockIdx
.
x
;
const
int
tid
=
threadIdx
.
x
;
__builtin_assume
(
tid
<
D
);
__shared__
float2
meta
[
parallel_blocks
];
if
(
tid
<
2
*
parallel_blocks
)
{
((
float
*
)
meta
)[
threadIdx
.
x
]
=
((
const
float
*
)
VKQ_meta
)
[
blockIdx
.
y
*
(
2
*
parallel_blocks
)
+
tid
];
}
__syncthreads
();
float
kqmax
=
meta
[
0
].
x
;
#pragma unroll
for
(
int
l
=
1
;
l
<
parallel_blocks
;
++
l
)
{
kqmax
=
max
(
kqmax
,
meta
[
l
].
x
);
}
float
VKQ_numerator
=
0.0
f
;
float
VKQ_denominator
=
0.0
f
;
#pragma unroll
for
(
int
l
=
0
;
l
<
parallel_blocks
;
++
l
)
{
const
float
diff
=
meta
[
l
].
x
-
kqmax
;
const
float
KQ_max_scale
=
expf
(
diff
);
const
uint32_t
ftz_mask
=
0xFFFFFFFF
*
(
diff
>
SOFTMAX_FTZ_THRESHOLD
);
*
((
uint32_t
*
)
&
KQ_max_scale
)
&=
ftz_mask
;
VKQ_numerator
+=
KQ_max_scale
*
VKQ_parts
[
l
*
gridDim
.
y
*
D
+
blockIdx
.
y
*
D
+
tid
];
VKQ_denominator
+=
KQ_max_scale
*
meta
[
l
].
y
;
}
dst
[
blockIdx
.
y
*
D
+
tid
]
=
VKQ_numerator
/
VKQ_denominator
;
}
template
<
int
D
,
int
parallel_blocks
>
void
launch_fattn
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
,
fattn_kernel_t
fattn_kernel
,
int
nwarps
,
int
cols_per_block
)
{
const
ggml_tensor
*
Q
=
dst
->
src
[
0
];
const
ggml_tensor
*
K
=
dst
->
src
[
1
];
const
ggml_tensor
*
V
=
dst
->
src
[
2
];
const
ggml_tensor
*
mask
=
dst
->
src
[
3
];
ggml_tensor
*
KQV
=
dst
;
GGML_ASSERT
(
Q
->
type
==
GGML_TYPE_F32
);
GGML_ASSERT
(
K
->
type
==
GGML_TYPE_F16
);
GGML_ASSERT
(
V
->
type
==
GGML_TYPE_F16
);
GGML_ASSERT
(
KQV
->
type
==
GGML_TYPE_F32
);
GGML_ASSERT
(
!
mask
||
mask
->
type
==
GGML_TYPE_F16
);
GGML_ASSERT
(
!
mask
||
mask
->
ne
[
1
]
>=
GGML_PAD
(
Q
->
ne
[
1
],
16
)
&&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"
);
GGML_ASSERT
(
K
->
ne
[
1
]
%
FATTN_KQ_STRIDE
==
0
&&
"Incorrect KV cache padding."
);
ggml_cuda_pool
&
pool
=
ctx
.
pool
();
cudaStream_t
main_stream
=
ctx
.
stream
();
ggml_cuda_pool_alloc
<
float
>
dst_tmp
(
pool
);
ggml_cuda_pool_alloc
<
float2
>
dst_tmp_meta
(
pool
);
if
(
parallel_blocks
>
1
)
{
dst_tmp
.
alloc
(
parallel_blocks
*
ggml_nelements
(
KQV
));
dst_tmp_meta
.
alloc
(
parallel_blocks
*
ggml_nrows
(
KQV
));
}
const
dim3
block_dim
(
WARP_SIZE
,
nwarps
,
1
);
const
dim3
blocks_num
(
parallel_blocks
*
((
Q
->
ne
[
1
]
+
cols_per_block
-
1
)
/
cols_per_block
),
Q
->
ne
[
2
],
Q
->
ne
[
3
]);
const
int
shmem
=
0
;
float
scale
=
1.0
f
;
float
max_bias
=
0.0
f
;
memcpy
(
&
scale
,
(
float
*
)
KQV
->
op_params
+
0
,
sizeof
(
float
));
memcpy
(
&
max_bias
,
(
float
*
)
KQV
->
op_params
+
1
,
sizeof
(
float
));
const
uint32_t
n_head
=
Q
->
ne
[
2
];
const
uint32_t
n_head_log2
=
1u
<<
(
uint32_t
)
floorf
(
log2f
((
float
)
n_head
));
const
float
m0
=
powf
(
2.0
f
,
-
(
max_bias
)
/
n_head_log2
);
const
float
m1
=
powf
(
2.0
f
,
-
(
max_bias
/
2.0
f
)
/
n_head_log2
);
fattn_kernel
<<<
blocks_num
,
block_dim
,
shmem
,
main_stream
>>>
(
(
const
char
*
)
Q
->
data
,
(
const
char
*
)
K
->
data
,
(
const
char
*
)
V
->
data
,
mask
?
((
const
char
*
)
mask
->
data
)
:
nullptr
,
(
parallel_blocks
)
==
1
?
(
float
*
)
KQV
->
data
:
dst_tmp
.
ptr
,
dst_tmp_meta
.
ptr
,
scale
,
max_bias
,
m0
,
m1
,
n_head_log2
,
Q
->
ne
[
0
],
Q
->
ne
[
1
],
Q
->
ne
[
2
],
Q
->
ne
[
3
],
K
->
ne
[
0
],
K
->
ne
[
1
],
K
->
ne
[
2
],
K
->
ne
[
3
],
mask
?
mask
->
ne
[
1
]
:
0
,
mask
?
mask
->
nb
[
1
]
:
0
,
Q
->
nb
[
1
],
Q
->
nb
[
2
],
Q
->
nb
[
3
],
K
->
nb
[
1
],
K
->
nb
[
2
],
K
->
nb
[
3
],
KQV
->
ne
[
0
],
KQV
->
ne
[
1
],
KQV
->
ne
[
2
],
KQV
->
ne
[
3
]
);
CUDA_CHECK
(
cudaGetLastError
());
if
((
parallel_blocks
)
==
1
)
{
return
;
}
const
dim3
block_dim_combine
(
D
,
1
,
1
);
const
dim3
blocks_num_combine
(
Q
->
ne
[
1
],
blocks_num
.
y
,
blocks_num
.
z
);
const
int
shmem_combine
=
0
;
flash_attn_combine_results
<
D
,
parallel_blocks
>
<<<
blocks_num_combine
,
block_dim_combine
,
shmem_combine
,
main_stream
>>>
(
dst_tmp
.
ptr
,
dst_tmp_meta
.
ptr
,
(
float
*
)
KQV
->
data
);
CUDA_CHECK
(
cudaGetLastError
());
}
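When parallel_blocks > 1, launch_fattn splits the KV sequence across blocks and flash_attn_combine_results merges the partial outputs using the (row max, row sum) pairs stored in dst_meta. The merge is an ordinary log-sum-exp rescale: every partial numerator and denominator is scaled to the global maximum, then one division normalizes the result. A plain C++ sketch for a single output element (illustrative names, FTZ threshold omitted):

#include <cmath>
#include <vector>

struct PartialResult {
    float vkq;   // un-normalized V*softmax(QK) accumulator from one block
    float kqmax; // row-wise maximum of the KQ scores seen by that block
    float kqsum; // sum of exp(KQ - kqmax) over the keys seen by that block
};

// Merge partial flash-attention results the same way the combine kernel does:
// rescale each partial pair to the global maximum, then divide once at the end.
float combine_partials(const std::vector<PartialResult> & parts) {
    float kqmax = parts[0].kqmax;
    for (const PartialResult & p : parts) {
        kqmax = std::fmax(kqmax, p.kqmax);
    }

    float numerator   = 0.0f;
    float denominator = 0.0f;
    for (const PartialResult & p : parts) {
        const float scale = std::exp(p.kqmax - kqmax); // <= 1, numerically safe
        numerator   += scale * p.vkq;
        denominator += scale * p.kqsum;
    }
    return numerator / denominator;
}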
llm/llama.cpp/ggml-cuda/fattn-tile-f16.cu deleted 100644 → 0
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f16.cuh"
#define FATTN_KQ_STRIDE_TILE_F16 64
template
<
int
D
,
int
ncols
,
int
nwarps
,
int
parallel_blocks
>
// D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__
(
nwarps
*
WARP_SIZE
,
1
)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static
__global__
__launch_bounds__
(
1024
)
void
flash_attn_tile_ext_f16
(
const
char
*
__restrict__
Q
,
const
char
*
__restrict__
K
,
const
char
*
__restrict__
V
,
const
char
*
__restrict__
mask
,
float
*
__restrict__
dst
,
float2
*
__restrict__
dst_meta
,
const
float
scale
,
const
float
max_bias
,
const
float
m0
,
const
float
m1
,
const
uint32_t
n_head_log2
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
ne03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
ne13
,
const
int
ne31
,
const
int
nb31
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
const
int
ne0
,
const
int
ne1
,
const
int
ne2
,
const
int
ne3
)
{
#if FP16_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const
int
ic0
=
(
blockIdx
.
x
/
parallel_blocks
)
*
ncols
;
// Index of the Q/QKV column to work on.
const
int
ip
=
blockIdx
.
x
%
parallel_blocks
;
// Index in group of blocks running for the same column in parallel.
const
int
gqa_ratio
=
ne02
/
ne12
;
// With grouped query attention there are > 1 Q matrices per K, V matrix.
const
float2
*
Q_f2
=
(
const
float2
*
)
(
Q
+
nb02
*
blockIdx
.
y
+
nb01
*
ic0
);
const
half2
*
K_h2
=
(
const
half2
*
)
(
K
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
const
half2
*
V_h2
=
(
const
half2
*
)
(
V
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
// K and V have same shape
const
half
*
maskh
=
(
const
half
*
)
mask
+
ne11
*
ic0
;
const
int
stride_KV2
=
nb11
/
sizeof
(
half2
);
const
float
slopef
=
get_alibi_slope
(
max_bias
,
blockIdx
.
y
,
n_head_log2
,
m0
,
m1
);
const
half
slopeh
=
__float2half
(
slopef
);
static_assert
(
D
%
(
2
*
WARP_SIZE
)
==
0
,
"D not divisible by 2*WARP_SIZE == 64."
);
__shared__
half
KQ
[
ncols
*
FATTN_KQ_STRIDE_TILE_F16
];
half2
*
KQ2
=
(
half2
*
)
KQ
;
__shared__
half2
KV_tmp
[
FATTN_KQ_STRIDE_TILE_F16
][
D
/
2
+
1
];
// Pad D to avoid memory bank conflicts.
half
kqmax
[
ncols
/
nwarps
];
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
kqmax
[
j0
/
nwarps
]
=
-
HALF_MAX_HALF
;
}
half2
kqsum
[
ncols
/
nwarps
]
=
{{
0.0
f
,
0.0
f
}};
half2
VKQ
[
ncols
/
nwarps
][(
D
/
2
)
/
WARP_SIZE
]
=
{{{
0.0
f
,
0.0
f
}}};
// Convert Q to half2 and store in registers:
__shared__
half2
Q_h2
[
ncols
][
D
/
2
];
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
const
int
j
=
j0
+
threadIdx
.
y
;
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
const
float2
tmp
=
ic0
+
j
<
ne01
?
Q_f2
[
j
*
(
nb01
/
sizeof
(
float2
))
+
i
]
:
make_float2
(
0.0
f
,
0.0
f
);
Q_h2
[
j
][
i
]
=
make_half2
(
scale
,
scale
)
*
make_half2
(
tmp
.
x
,
tmp
.
y
);
}
}
__syncthreads
();
const
int
k_start
=
parallel_blocks
==
1
?
0
:
ip
*
FATTN_KQ_STRIDE_TILE_F16
;
for
(
int
k_VKQ_0
=
k_start
;
k_VKQ_0
<
ne11
;
k_VKQ_0
+=
parallel_blocks
*
FATTN_KQ_STRIDE_TILE_F16
)
{
// Calculate KQ tile and keep track of new maximum KQ values:
half
kqmax_new
[
ncols
/
nwarps
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
/
nwarps
;
++
j
)
{
kqmax_new
[
j
]
=
kqmax
[
j
];
}
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
FATTN_KQ_STRIDE_TILE_F16
;
i_KQ_0
+=
nwarps
)
{
const
int
i_KQ
=
i_KQ_0
+
threadIdx
.
y
;
#pragma unroll
for
(
int
k_KQ_0
=
0
;
k_KQ_0
<
D
/
2
;
k_KQ_0
+=
WARP_SIZE
)
{
const
int
k_KQ
=
k_KQ_0
+
threadIdx
.
x
;
KV_tmp
[
i_KQ
][
k_KQ
]
=
K_h2
[(
k_VKQ_0
+
i_KQ
)
*
stride_KV2
+
k_KQ
];
}
}
__syncthreads
();
half2
sum2
[
FATTN_KQ_STRIDE_TILE_F16
/
WARP_SIZE
][
ncols
/
nwarps
]
=
{{{
0.0
f
,
0.0
f
}}};
#pragma unroll
for
(
int
k_KQ
=
0
;
k_KQ
<
D
/
2
;
++
k_KQ
)
{
half2
K_k
[
FATTN_KQ_STRIDE_TILE_F16
/
WARP_SIZE
];
half2
Q_k
[
ncols
/
nwarps
];
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
FATTN_KQ_STRIDE_TILE_F16
;
i_KQ_0
+=
WARP_SIZE
)
{
const
int
i_KQ
=
i_KQ_0
+
threadIdx
.
x
;
K_k
[
i_KQ_0
/
WARP_SIZE
]
=
KV_tmp
[
i_KQ
][
k_KQ
];
}
#pragma unroll
for
(
int
j_KQ_0
=
0
;
j_KQ_0
<
ncols
;
j_KQ_0
+=
nwarps
)
{
const
int
j_KQ
=
j_KQ_0
+
threadIdx
.
y
;
Q_k
[
j_KQ_0
/
nwarps
]
=
Q_h2
[
j_KQ
][
k_KQ
];
}
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
FATTN_KQ_STRIDE_TILE_F16
;
i_KQ_0
+=
WARP_SIZE
)
{
#pragma unroll
for
(
int
j_KQ_0
=
0
;
j_KQ_0
<
ncols
;
j_KQ_0
+=
nwarps
)
{
sum2
[
i_KQ_0
/
WARP_SIZE
][
j_KQ_0
/
nwarps
]
+=
K_k
[
i_KQ_0
/
WARP_SIZE
]
*
Q_k
[
j_KQ_0
/
nwarps
];
}
}
}
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
FATTN_KQ_STRIDE_TILE_F16
;
i_KQ_0
+=
WARP_SIZE
)
{
const
int
i_KQ
=
i_KQ_0
+
threadIdx
.
x
;
#pragma unroll
for
(
int
j_KQ_0
=
0
;
j_KQ_0
<
ncols
;
j_KQ_0
+=
nwarps
)
{
const
int
j_KQ
=
j_KQ_0
+
threadIdx
.
y
;
half
sum
=
__low2half
(
sum2
[
i_KQ_0
/
WARP_SIZE
][
j_KQ_0
/
nwarps
])
+
__high2half
(
sum2
[
i_KQ_0
/
WARP_SIZE
][
j_KQ_0
/
nwarps
]);
sum
+=
mask
?
slopeh
*
maskh
[
j_KQ
*
ne11
+
k_VKQ_0
+
i_KQ
]
:
__float2half
(
0.0
f
);
kqmax_new
[
j_KQ_0
/
nwarps
]
=
ggml_cuda_hmax
(
kqmax_new
[
j_KQ_0
/
nwarps
],
sum
);
KQ
[
j_KQ
*
FATTN_KQ_STRIDE_TILE_F16
+
i_KQ
]
=
sum
;
}
}
__syncthreads
();
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
const
int
j
=
j0
+
threadIdx
.
y
;
kqmax_new
[
j0
/
nwarps
]
=
warp_reduce_max
(
kqmax_new
[
j0
/
nwarps
]);
const
half2
KQ_max_scale
=
__half2half2
(
hexp
(
kqmax
[
j0
/
nwarps
]
-
kqmax_new
[
j0
/
nwarps
]));
kqmax
[
j0
/
nwarps
]
=
kqmax_new
[
j0
/
nwarps
];
#pragma unroll
for
(
int
i0
=
0
;
i0
<
FATTN_KQ_STRIDE_TILE_F16
/
2
;
i0
+=
WARP_SIZE
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
const
half2
diff
=
KQ2
[
j
*
(
FATTN_KQ_STRIDE_TILE_F16
/
2
)
+
i
]
-
__half2half2
(
kqmax
[
j0
/
nwarps
]);
const
half2
val
=
h2exp
(
diff
);
kqsum
[
j0
/
nwarps
]
=
kqsum
[
j0
/
nwarps
]
*
KQ_max_scale
+
val
;
KQ2
[
j
*
(
FATTN_KQ_STRIDE_TILE_F16
/
2
)
+
i
]
=
val
;
}
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
VKQ
[
j0
/
nwarps
][
i0
/
WARP_SIZE
]
*=
KQ_max_scale
;
}
}
__syncthreads
();
#pragma unroll
for
(
int
k0
=
0
;
k0
<
FATTN_KQ_STRIDE_TILE_F16
;
k0
+=
nwarps
)
{
const
int
k
=
k0
+
threadIdx
.
y
;
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
KV_tmp
[
k
][
i
]
=
V_h2
[(
k_VKQ_0
+
k
)
*
stride_KV2
+
i
];
}
}
__syncthreads
();
#pragma unroll
for
(
int
k0
=
0
;
k0
<
FATTN_KQ_STRIDE_TILE_F16
;
k0
+=
2
)
{
half2
V_k
[(
D
/
2
)
/
WARP_SIZE
][
2
];
half2
KQ_k
[
ncols
/
nwarps
];
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
V_k
[
i0
/
WARP_SIZE
][
0
]
=
KV_tmp
[
k0
+
0
][
i
];
V_k
[
i0
/
WARP_SIZE
][
1
]
=
KV_tmp
[
k0
+
1
][
i
];
}
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
const
int
j
=
j0
+
threadIdx
.
y
;
KQ_k
[
j0
/
nwarps
]
=
KQ2
[
j
*
(
FATTN_KQ_STRIDE_TILE_F16
/
2
)
+
k0
/
2
];
}
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
VKQ
[
j0
/
nwarps
][
i0
/
WARP_SIZE
]
+=
V_k
[
i0
/
WARP_SIZE
][
0
]
*
__low2half2
(
KQ_k
[
j0
/
nwarps
]);
VKQ
[
j0
/
nwarps
][
i0
/
WARP_SIZE
]
+=
V_k
[
i0
/
WARP_SIZE
][
1
]
*
__high2half2
(
KQ_k
[
j0
/
nwarps
]);
}
}
}
__syncthreads
();
}
#pragma unroll
for
(
int
j_VKQ_0
=
0
;
j_VKQ_0
<
ncols
;
j_VKQ_0
+=
nwarps
)
{
const
int
j_VKQ
=
j_VKQ_0
+
threadIdx
.
y
;
if
(
ic0
+
j_VKQ
>=
ne01
)
{
return
;
}
half
kqsum_j
=
__low2half
(
kqsum
[
j_VKQ_0
/
nwarps
])
+
__high2half
(
kqsum
[
j_VKQ_0
/
nwarps
]);
kqsum_j
=
warp_reduce_sum
(
kqsum_j
);
#pragma unroll
for
(
int
i00
=
0
;
i00
<
D
;
i00
+=
2
*
WARP_SIZE
)
{
const
int
i0
=
i00
+
2
*
threadIdx
.
x
;
half2
dst_val
=
VKQ
[
j_VKQ_0
/
nwarps
][
i0
/
(
2
*
WARP_SIZE
)];
if
(
parallel_blocks
==
1
)
{
dst_val
/=
__half2half2
(
kqsum_j
);
}
const
int
j_dst
=
(
ic0
+
j_VKQ
)
*
parallel_blocks
+
ip
;
dst
[
j_dst
*
D
*
gridDim
.
y
+
D
*
blockIdx
.
y
+
i0
+
0
]
=
__low2float
(
dst_val
);
dst
[
j_dst
*
D
*
gridDim
.
y
+
D
*
blockIdx
.
y
+
i0
+
1
]
=
__high2float
(
dst_val
);
}
if
(
parallel_blocks
!=
1
&&
threadIdx
.
x
==
0
)
{
dst_meta
[(
ic0
+
j_VKQ
)
*
gridDim
.
y
*
parallel_blocks
+
blockIdx
.
y
*
parallel_blocks
+
ip
]
=
make_float2
(
kqmax
[
j_VKQ_0
/
nwarps
],
kqsum_j
);
}
}
#else
NO_DEVICE_CODE
;
#endif // FP16_AVAILABLE
}
template
<
int
cols_per_block
,
int
parallel_blocks
>
void
launch_fattn_tile_f16_64_128
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
ggml_tensor
*
Q
=
dst
->
src
[
0
];
switch
(
Q
->
ne
[
0
])
{
case
64
:
{
constexpr
int
D
=
64
;
constexpr
int
nwarps
=
8
;
fattn_kernel_t
fattn_kernel
=
flash_attn_tile_ext_f16
<
D
,
cols_per_block
,
nwarps
,
parallel_blocks
>
;
launch_fattn
<
D
,
parallel_blocks
>
(
ctx
,
dst
,
fattn_kernel
,
nwarps
,
cols_per_block
);
}
break
;
case
128
:
{
constexpr
int
D
=
128
;
constexpr
int
nwarps
=
8
;
fattn_kernel_t
fattn_kernel
=
flash_attn_tile_ext_f16
<
D
,
cols_per_block
,
nwarps
,
parallel_blocks
>
;
launch_fattn
<
D
,
parallel_blocks
>
(
ctx
,
dst
,
fattn_kernel
,
nwarps
,
cols_per_block
);
}
break
;
default:
{
GGML_ASSERT
(
false
&&
"FlashAttention without tensor cores only supports head sizes 64 and 128."
);
}
break
;
}
}
void
ggml_cuda_flash_attn_ext_tile_f16
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
ggml_tensor
*
KQV
=
dst
;
const
ggml_tensor
*
Q
=
dst
->
src
[
0
];
const
int32_t
precision
=
KQV
->
op_params
[
2
];
GGML_ASSERT
(
precision
==
GGML_PREC_DEFAULT
);
if
(
Q
->
ne
[
1
]
<=
16
)
{
constexpr
int
cols_per_block
=
16
;
constexpr
int
parallel_blocks
=
4
;
launch_fattn_tile_f16_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
return
;
}
if
(
Q
->
ne
[
1
]
<=
32
)
{
constexpr
int
cols_per_block
=
32
;
constexpr
int
parallel_blocks
=
4
;
launch_fattn_tile_f16_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
return
;
}
constexpr
int
cols_per_block
=
32
;
constexpr
int
parallel_blocks
=
1
;
launch_fattn_tile_f16_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
}
llm/llama.cpp/ggml-cuda/fattn-tile-f16.cuh deleted 100644 → 0
#include "common.cuh"
void
ggml_cuda_flash_attn_ext_tile_f16
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
);
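launch_fattn derives m0 and m1 from max_bias and n_head_log2, and the attention kernels turn these into one ALiBi slope per head via get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1). A host-side sketch of that per-head slope, reconstructed from the m0/m1 definitions above (an illustration of the usual scheme, not a copy of common.cuh):

#include <cmath>
#include <cstdint>

// Per-head ALiBi slope, following the construction used by launch_fattn:
//   m0 = 2^(-max_bias / n_head_log2), m1 = 2^(-(max_bias/2) / n_head_log2),
// where n_head_log2 is the largest power of two <= n_head.
float alibi_slope(const float max_bias, const uint32_t head, const uint32_t n_head_log2) {
    if (max_bias <= 0.0f) {
        return 1.0f; // no positional bias requested
    }
    const float m0 = std::pow(2.0f, -max_bias          / n_head_log2);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);

    // The first n_head_log2 heads use successive powers of m0,
    // the remaining heads use odd powers of m1.
    if (head < n_head_log2) {
        return std::pow(m0, (float) (head + 1));
    }
    return std::pow(m1, (float) (2*(head - n_head_log2) + 1));
}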
llm/llama.cpp/ggml-cuda/fattn-tile-f32.cu deleted 100644 → 0
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f32.cuh"
#define FATTN_KQ_STRIDE_TILE_F32 32
template
<
int
D
,
int
ncols
,
int
nwarps
,
int
parallel_blocks
>
// D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__
(
nwarps
*
WARP_SIZE
,
1
)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static
__global__
__launch_bounds__
(
1024
)
void
flash_attn_tile_ext_f32
(
const
char
*
__restrict__
Q
,
const
char
*
__restrict__
K
,
const
char
*
__restrict__
V
,
const
char
*
__restrict__
mask
,
float
*
__restrict__
dst
,
float2
*
__restrict__
dst_meta
,
const
float
scale
,
const
float
max_bias
,
const
float
m0
,
const
float
m1
,
const
uint32_t
n_head_log2
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
ne03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
ne13
,
const
int
ne31
,
const
int
nb31
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
const
int
ne0
,
const
int
ne1
,
const
int
ne2
,
const
int
ne3
)
{
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const
int
ic0
=
(
blockIdx
.
x
/
parallel_blocks
)
*
ncols
;
// Index of the Q/QKV column to work on.
const
int
ip
=
blockIdx
.
x
%
parallel_blocks
;
// Index in group of blocks running for the same column in parallel.
const
int
gqa_ratio
=
ne02
/
ne12
;
// With grouped query attention there are > 1 Q matrices per K, V matrix.
const
float2
*
Q_f2
=
(
const
float2
*
)
(
Q
+
nb02
*
blockIdx
.
y
+
nb01
*
ic0
);
const
half2
*
K_h2
=
(
const
half2
*
)
(
K
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
const
half2
*
V_h2
=
(
const
half2
*
)
(
V
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
// K and V have same shape
const
half
*
maskh
=
(
const
half
*
)
mask
+
ne11
*
ic0
;
const
int
stride_KV2
=
nb11
/
sizeof
(
half2
);
const
float
slope
=
get_alibi_slope
(
max_bias
,
blockIdx
.
y
,
n_head_log2
,
m0
,
m1
);
static_assert
(
D
%
(
2
*
WARP_SIZE
)
==
0
,
"D not divisible by 2*WARP_SIZE == 64."
);
__shared__
float
KQ
[
ncols
*
FATTN_KQ_STRIDE_TILE_F32
];
__shared__
float
KV_tmp
[
FATTN_KQ_STRIDE_TILE_F32
][
D
+
1
];
// Pad D to avoid memory bank conflicts.
float2
*
KV_tmp2
=
(
float2
*
)
KV_tmp
;
float
kqmax
[
ncols
/
nwarps
];
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
kqmax
[
j0
/
nwarps
]
=
-
FLT_MAX
/
2.0
f
;
}
float
kqsum
[
ncols
/
nwarps
]
=
{
0.0
f
};
float2
VKQ
[
ncols
/
nwarps
][(
D
/
2
)
/
WARP_SIZE
]
=
{{{
0.0
f
,
0.0
f
}}};
// Convert Q to half2 and store in registers:
__shared__
float
Q_f
[
ncols
][
D
];
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
const
int
j
=
j0
+
threadIdx
.
y
;
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
;
i0
+=
2
*
WARP_SIZE
)
{
float2
tmp
=
ic0
+
j
<
ne01
?
Q_f2
[
j
*
(
nb01
/
sizeof
(
float2
))
+
i0
/
2
+
threadIdx
.
x
]
:
make_float2
(
0.0
f
,
0.0
f
);
Q_f
[
j
][
i0
+
0
*
WARP_SIZE
+
threadIdx
.
x
]
=
tmp
.
x
*
scale
;
Q_f
[
j
][
i0
+
1
*
WARP_SIZE
+
threadIdx
.
x
]
=
tmp
.
y
*
scale
;
}
}
__syncthreads
();
const
int
k_start
=
parallel_blocks
==
1
?
0
:
ip
*
FATTN_KQ_STRIDE_TILE_F32
;
for
(
int
k_VKQ_0
=
k_start
;
k_VKQ_0
<
ne11
;
k_VKQ_0
+=
parallel_blocks
*
FATTN_KQ_STRIDE_TILE_F32
)
{
// Calculate KQ tile and keep track of new maximum KQ values:
float
kqmax_new
[
ncols
/
nwarps
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
/
nwarps
;
++
j
)
{
kqmax_new
[
j
]
=
kqmax
[
j
];
}
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
FATTN_KQ_STRIDE_TILE_F32
;
i_KQ_0
+=
nwarps
)
{
const
int
i_KQ
=
i_KQ_0
+
threadIdx
.
y
;
#pragma unroll
for
(
int
k_KQ_0
=
0
;
k_KQ_0
<
D
;
k_KQ_0
+=
2
*
WARP_SIZE
)
{
const
half2
tmp
=
K_h2
[(
k_VKQ_0
+
i_KQ
)
*
stride_KV2
+
k_KQ_0
/
2
+
threadIdx
.
x
];
KV_tmp
[
i_KQ
][
k_KQ_0
+
0
*
WARP_SIZE
+
threadIdx
.
x
]
=
__low2float
(
tmp
);
KV_tmp
[
i_KQ
][
k_KQ_0
+
1
*
WARP_SIZE
+
threadIdx
.
x
]
=
__high2float
(
tmp
);
}
}
__syncthreads
();
float
sum
[
FATTN_KQ_STRIDE_TILE_F32
/
WARP_SIZE
][
ncols
/
nwarps
]
=
{{
0.0
f
}};
#pragma unroll
for
(
int
k_KQ
=
0
;
k_KQ
<
D
;
++
k_KQ
)
{
float
K_k
[
FATTN_KQ_STRIDE_TILE_F32
/
WARP_SIZE
];
float
Q_k
[
ncols
/
nwarps
];
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
FATTN_KQ_STRIDE_TILE_F32
;
i_KQ_0
+=
WARP_SIZE
)
{
const
int
i_KQ
=
i_KQ_0
+
threadIdx
.
x
;
K_k
[
i_KQ_0
/
WARP_SIZE
]
=
KV_tmp
[
i_KQ
][
k_KQ
];
}
#pragma unroll
for
(
int
j_KQ_0
=
0
;
j_KQ_0
<
ncols
;
j_KQ_0
+=
nwarps
)
{
const
int
j_KQ
=
j_KQ_0
+
threadIdx
.
y
;
Q_k
[
j_KQ_0
/
nwarps
]
=
Q_f
[
j_KQ
][
k_KQ
];
}
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
FATTN_KQ_STRIDE_TILE_F32
;
i_KQ_0
+=
WARP_SIZE
)
{
#pragma unroll
for
(
int
j_KQ_0
=
0
;
j_KQ_0
<
ncols
;
j_KQ_0
+=
nwarps
)
{
sum
[
i_KQ_0
/
WARP_SIZE
][
j_KQ_0
/
nwarps
]
+=
K_k
[
i_KQ_0
/
WARP_SIZE
]
*
Q_k
[
j_KQ_0
/
nwarps
];
}
}
}
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
FATTN_KQ_STRIDE_TILE_F32
;
i_KQ_0
+=
WARP_SIZE
)
{
const
int
i_KQ
=
i_KQ_0
+
threadIdx
.
x
;
#pragma unroll
for
(
int
j_KQ_0
=
0
;
j_KQ_0
<
ncols
;
j_KQ_0
+=
nwarps
)
{
const
int
j_KQ
=
j_KQ_0
+
threadIdx
.
y
;
sum
[
i_KQ_0
/
WARP_SIZE
][
j_KQ_0
/
nwarps
]
+=
mask
?
slope
*
__half2float
(
maskh
[
j_KQ
*
ne11
+
k_VKQ_0
+
i_KQ
])
:
0.0
f
;
kqmax_new
[
j_KQ_0
/
nwarps
]
=
fmaxf
(
kqmax_new
[
j_KQ_0
/
nwarps
],
sum
[
i_KQ_0
/
WARP_SIZE
][
j_KQ_0
/
nwarps
]);
KQ
[
j_KQ
*
FATTN_KQ_STRIDE_TILE_F32
+
i_KQ
]
=
sum
[
i_KQ_0
/
WARP_SIZE
][
j_KQ_0
/
nwarps
];
}
}
__syncthreads
();
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
const
int
j
=
j0
+
threadIdx
.
y
;
kqmax_new
[
j0
/
nwarps
]
=
warp_reduce_max
(
kqmax_new
[
j0
/
nwarps
]);
const
float
KQ_max_scale
=
expf
(
kqmax
[
j0
/
nwarps
]
-
kqmax_new
[
j0
/
nwarps
]);
kqmax
[
j0
/
nwarps
]
=
kqmax_new
[
j0
/
nwarps
];
float
kqsum_add
=
0.0
f
;
#pragma unroll
for
(
int
i0
=
0
;
i0
<
FATTN_KQ_STRIDE_TILE_F32
;
i0
+=
WARP_SIZE
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
const
float
diff
=
KQ
[
j
*
FATTN_KQ_STRIDE_TILE_F32
+
i
]
-
kqmax
[
j0
/
nwarps
];
const
float
val
=
expf
(
diff
);
kqsum_add
+=
val
;
KQ
[
j
*
FATTN_KQ_STRIDE_TILE_F32
+
i
]
=
val
;
}
kqsum
[
j0
/
nwarps
]
=
kqsum
[
j0
/
nwarps
]
*
KQ_max_scale
+
kqsum_add
;
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
VKQ
[
j0
/
nwarps
][
i0
/
WARP_SIZE
].
x
*=
KQ_max_scale
;
VKQ
[
j0
/
nwarps
][
i0
/
WARP_SIZE
].
y
*=
KQ_max_scale
;
}
}
__syncthreads
();
#pragma unroll
for
(
int
k0
=
0
;
k0
<
FATTN_KQ_STRIDE_TILE_F32
;
k0
+=
nwarps
)
{
const
int
k
=
k0
+
threadIdx
.
y
;
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
KV_tmp2
[
k
*
(
D
/
2
)
+
i
].
x
=
__low2float
(
V_h2
[(
k_VKQ_0
+
k
)
*
stride_KV2
+
i
]);
KV_tmp2
[
k
*
(
D
/
2
)
+
i
].
y
=
__high2float
(
V_h2
[(
k_VKQ_0
+
k
)
*
stride_KV2
+
i
]);
}
}
__syncthreads
();
#pragma unroll
for
(
int
k
=
0
;
k
<
FATTN_KQ_STRIDE_TILE_F32
;
++
k
)
{
float2
V_k
[(
D
/
2
)
/
WARP_SIZE
];
float
KQ_k
[
ncols
/
nwarps
];
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
V_k
[
i0
/
WARP_SIZE
]
=
KV_tmp2
[
k
*
(
D
/
2
)
+
i
];
}
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
const
int
j
=
j0
+
threadIdx
.
y
;
KQ_k
[
j0
/
nwarps
]
=
KQ
[
j
*
FATTN_KQ_STRIDE_TILE_F32
+
k
];
}
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
VKQ
[
j0
/
nwarps
][
i0
/
WARP_SIZE
].
x
+=
V_k
[
i0
/
WARP_SIZE
].
x
*
KQ_k
[
j0
/
nwarps
];
VKQ
[
j0
/
nwarps
][
i0
/
WARP_SIZE
].
y
+=
V_k
[
i0
/
WARP_SIZE
].
y
*
KQ_k
[
j0
/
nwarps
];
}
}
}
__syncthreads
();
}
#pragma unroll
for
(
int
j_VKQ_0
=
0
;
j_VKQ_0
<
ncols
;
j_VKQ_0
+=
nwarps
)
{
const
int
j_VKQ
=
j_VKQ_0
+
threadIdx
.
y
;
if
(
ic0
+
j_VKQ
>=
ne01
)
{
return
;
}
float
kqsum_j
=
kqsum
[
j_VKQ_0
/
nwarps
];
kqsum_j
=
warp_reduce_sum
(
kqsum_j
);
#pragma unroll
for
(
int
i00
=
0
;
i00
<
D
;
i00
+=
2
*
WARP_SIZE
)
{
const
int
i0
=
i00
+
2
*
threadIdx
.
x
;
float2
dst_val
=
VKQ
[
j_VKQ_0
/
nwarps
][
i0
/
(
2
*
WARP_SIZE
)];
if
(
parallel_blocks
==
1
)
{
dst_val
.
x
/=
kqsum_j
;
dst_val
.
y
/=
kqsum_j
;
}
const
int
j_dst
=
(
ic0
+
j_VKQ
)
*
parallel_blocks
+
ip
;
dst
[
j_dst
*
D
*
gridDim
.
y
+
D
*
blockIdx
.
y
+
i0
+
0
]
=
dst_val
.
x
;
dst
[
j_dst
*
D
*
gridDim
.
y
+
D
*
blockIdx
.
y
+
i0
+
1
]
=
dst_val
.
y
;
}
if
(
parallel_blocks
!=
1
&&
threadIdx
.
x
==
0
)
{
dst_meta
[(
ic0
+
j_VKQ
)
*
gridDim
.
y
*
parallel_blocks
+
blockIdx
.
y
*
parallel_blocks
+
ip
]
=
make_float2
(
kqmax
[
j_VKQ_0
/
nwarps
],
kqsum_j
);
}
}
}
template
<
int
cols_per_block
,
int
parallel_blocks
>
void
launch_fattn_tile_f32_64_128
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
ggml_tensor
*
Q
=
dst
->
src
[
0
];
switch
(
Q
->
ne
[
0
])
{
case
64
:
{
constexpr
int
D
=
64
;
constexpr
int
nwarps
=
8
;
fattn_kernel_t
fattn_kernel
=
flash_attn_tile_ext_f32
<
D
,
cols_per_block
,
nwarps
,
parallel_blocks
>
;
launch_fattn
<
D
,
parallel_blocks
>
(
ctx
,
dst
,
fattn_kernel
,
nwarps
,
cols_per_block
);
}
break
;
case
128
:
{
constexpr
int
D
=
128
;
constexpr
int
nwarps
=
8
;
fattn_kernel_t
fattn_kernel
=
flash_attn_tile_ext_f32
<
D
,
cols_per_block
,
nwarps
,
parallel_blocks
>
;
launch_fattn
<
D
,
parallel_blocks
>
(
ctx
,
dst
,
fattn_kernel
,
nwarps
,
cols_per_block
);
}
break
;
default:
{
GGML_ASSERT
(
false
&&
"FlashAttention without tensor cores only supports head sizes 64 and 128."
);
}
break
;
}
}
void
ggml_cuda_flash_attn_ext_tile_f32
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
ggml_tensor
*
Q
=
dst
->
src
[
0
];
if
(
Q
->
ne
[
1
]
<=
16
)
{
constexpr
int
cols_per_block
=
16
;
constexpr
int
parallel_blocks
=
4
;
launch_fattn_tile_f32_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
return
;
}
if
(
Q
->
ne
[
1
]
<=
32
)
{
constexpr
int
cols_per_block
=
32
;
constexpr
int
parallel_blocks
=
4
;
launch_fattn_tile_f32_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
return
;
}
constexpr
int
cols_per_block
=
32
;
constexpr
int
parallel_blocks
=
1
;
launch_fattn_tile_f32_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
}
llm/llama.cpp/ggml-cuda/fattn-tile-f32.cuh deleted 100644 → 0
#include "common.cuh"
void
ggml_cuda_flash_attn_ext_tile_f32
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
);
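Both tile kernels keep, per query row, a running score maximum (kqmax), a running softmax denominator (kqsum), and an un-normalized V accumulator (VKQ); whenever a new KV tile raises the maximum, the old accumulators are rescaled by exp(old_max - new_max). A scalar C++ sketch of that streaming-softmax update for a single query row (illustrative only):

#include <cmath>
#include <limits>
#include <vector>

// Streaming ("online") softmax-weighted average: the scalar analogue of the
// per-tile kqmax / kqsum / VKQ update in the kernels above.
float streaming_attention(const std::vector<float> & scores, const std::vector<float> & values) {
    // Like HALF_MAX_HALF in fattn-common.cuh, start from a large finite negative
    // value rather than -INFINITY so the first rescale cannot produce NaN.
    float kqmax = -std::numeric_limits<float>::max()/2.0f;
    float kqsum = 0.0f; // running sum of exp(score - kqmax)
    float vkq   = 0.0f; // running un-normalized weighted sum of values

    for (size_t i = 0; i < scores.size(); ++i) {
        const float kqmax_new = std::fmax(kqmax, scores[i]);
        const float scale     = std::exp(kqmax - kqmax_new); // rescales the old accumulators
        const float val       = std::exp(scores[i] - kqmax_new);

        kqsum = kqsum*scale + val;
        vkq   = vkq  *scale + val*values[i];
        kqmax = kqmax_new;
    }
    return vkq / kqsum; // final normalization, as in the parallel_blocks == 1 path
}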
llm/llama.cpp/ggml-cuda/fattn-vec-f16.cu deleted 100644 → 0
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-vec-f16.cuh"
template
<
int
D
,
int
ncols
,
int
parallel_blocks
>
// D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__
(
D
,
1
)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static
__global__
__launch_bounds__
(
1024
)
void
flash_attn_vec_ext_f16
(
const
char
*
__restrict__
Q
,
const
char
*
__restrict__
K
,
const
char
*
__restrict__
V
,
const
char
*
__restrict__
mask
,
float
*
__restrict__
dst
,
float2
*
__restrict__
dst_meta
,
const
float
scale
,
const
float
max_bias
,
const
float
m0
,
const
float
m1
,
const
uint32_t
n_head_log2
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
ne03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
ne13
,
const
int
ne31
,
const
int
nb31
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
const
int
ne0
,
const
int
ne1
,
const
int
ne2
,
const
int
ne3
)
{
#if FP16_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const
int
ic0
=
(
blockIdx
.
x
/
parallel_blocks
)
*
ncols
;
// Index of the Q/QKV column to work on.
const
int
ip
=
blockIdx
.
x
%
parallel_blocks
;
// Index in group of blocks running for the same column in parallel.
const
int
gqa_ratio
=
ne02
/
ne12
;
// With grouped query attention there are > 1 Q matrices per K, V matrix.
const
float2
*
Q_f2
=
(
const
float2
*
)
(
Q
+
nb02
*
blockIdx
.
y
+
nb01
*
ic0
);
const
half2
*
K_h2
=
(
const
half2
*
)
(
K
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
const
half
*
V_h
=
(
const
half
*
)
(
V
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
// K and V have same shape
const
half
*
maskh
=
(
const
half
*
)
mask
+
ne11
*
ic0
;
const
int
stride_KV
=
nb11
/
sizeof
(
half
);
const
int
stride_KV2
=
nb11
/
sizeof
(
half2
);
const
float
slopef
=
get_alibi_slope
(
max_bias
,
blockIdx
.
y
,
n_head_log2
,
m0
,
m1
);
const
half
slopeh
=
__float2half
(
slopef
);
static_assert
(
D
%
(
2
*
WARP_SIZE
)
==
0
,
"D not divisible by 2*WARP_SIZE == 64."
);
constexpr
int
nwarps
=
D
/
WARP_SIZE
;
const
int
tid
=
WARP_SIZE
*
threadIdx
.
y
+
threadIdx
.
x
;
__builtin_assume
(
tid
<
D
);
__shared__
half
KQ
[
ncols
*
D
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
KQ
[
j
*
D
+
tid
]
=
-
HALF_MAX_HALF
;
}
half2
*
KQ2
=
(
half2
*
)
KQ
;
half
kqmax
[
ncols
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
kqmax
[
j
]
=
-
HALF_MAX_HALF
;
}
half
kqsum
[
ncols
]
=
{
0.0
f
};
__shared__
half
kqmax_shared
[
ncols
][
WARP_SIZE
];
__shared__
half
kqsum_shared
[
ncols
][
WARP_SIZE
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
if
(
threadIdx
.
y
==
0
)
{
kqmax_shared
[
j
][
threadIdx
.
x
]
=
-
HALF_MAX_HALF
;
kqsum_shared
[
j
][
threadIdx
.
x
]
=
0.0
f
;
}
}
__syncthreads
();
// Convert Q to half2 and store in registers:
half2
Q_h2
[
ncols
][
D
/
(
2
*
WARP_SIZE
)];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
const
float2
tmp
=
ncols
<=
2
||
ic0
+
j
<
ne01
?
Q_f2
[
j
*
(
nb01
/
sizeof
(
float2
))
+
i
]
:
make_float2
(
0.0
f
,
0.0
f
);
Q_h2
[
j
][
i0
/
WARP_SIZE
]
=
make_half2
(
scale
,
scale
)
*
make_half2
(
tmp
.
x
,
tmp
.
y
);
}
}
half2
VKQ
[
ncols
]
=
{{
0.0
f
,
0.0
f
}};
const
int
k_start
=
parallel_blocks
==
1
?
0
:
ip
*
D
;
for
(
int
k_VKQ_0
=
k_start
;
k_VKQ_0
<
ne11
;
k_VKQ_0
+=
parallel_blocks
*
D
)
{
// Calculate KQ tile and keep track of new maximum KQ values:
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
// Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
half
kqmax_new
=
kqmax
[
0
];
half
kqmax_new_arr
[
ncols
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
kqmax_new_arr
[
j
]
=
kqmax
[
j
];
}
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
D
;
i_KQ_0
+=
nwarps
)
{
const
int
i_KQ
=
i_KQ_0
+
threadIdx
.
y
;
if
((
i_KQ_0
+
nwarps
>
D
&&
i_KQ
>=
D
)
||
(
FATTN_KQ_STRIDE
%
D
!=
0
&&
k_VKQ_0
+
i_KQ
>=
ne11
))
{
break
;
}
half2
sum2
[
ncols
]
=
{{
0.0
f
,
0.0
f
}};
#pragma unroll
for
(
int
k_KQ_0
=
0
;
k_KQ_0
<
D
/
2
;
k_KQ_0
+=
WARP_SIZE
)
{
const
int
k_KQ
=
k_KQ_0
+
threadIdx
.
x
;
const
half2
K_ik
=
K_h2
[(
k_VKQ_0
+
i_KQ
)
*
stride_KV2
+
k_KQ
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
sum2
[
j
]
+=
K_ik
*
Q_h2
[
j
][
k_KQ_0
/
WARP_SIZE
];
}
}
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
sum2
[
j
]
=
warp_reduce_sum
(
sum2
[
j
]);
half
sum
=
__low2half
(
sum2
[
j
])
+
__high2half
(
sum2
[
j
]);
sum
+=
mask
?
slopeh
*
maskh
[
j
*
ne11
+
k_VKQ_0
+
i_KQ
]
:
__float2half
(
0.0
f
);
if
(
ncols
==
1
)
{
kqmax_new
=
ggml_cuda_hmax
(
kqmax_new
,
sum
);
}
else
{
kqmax_new_arr
[
j
]
=
ggml_cuda_hmax
(
kqmax_new_arr
[
j
],
sum
);
}
if
(
threadIdx
.
x
==
0
)
{
KQ
[
j
*
D
+
i_KQ
]
=
sum
;
}
}
}
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
half
kqmax_new_j
=
ncols
==
1
?
kqmax_new
:
kqmax_new_arr
[
j
];
kqmax_new_j
=
warp_reduce_max
(
kqmax_new_j
);
if
(
threadIdx
.
x
==
0
)
{
kqmax_shared
[
j
][
threadIdx
.
y
]
=
kqmax_new_j
;
}
}
__syncthreads
();
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
half
kqmax_new_j
=
kqmax_shared
[
j
][
threadIdx
.
x
];
kqmax_new_j
=
warp_reduce_max
(
kqmax_new_j
);
const
half
KQ_max_scale
=
hexp
(
kqmax
[
j
]
-
kqmax_new_j
);
kqmax
[
j
]
=
kqmax_new_j
;
const
half
val
=
hexp
(
KQ
[
j
*
D
+
tid
]
-
kqmax
[
j
]);
kqsum
[
j
]
=
kqsum
[
j
]
*
KQ_max_scale
+
val
;
KQ
[
j
*
D
+
tid
]
=
val
;
VKQ
[
j
]
*=
__half2half2
(
KQ_max_scale
);
}
__syncthreads
();
#pragma unroll
for
(
int
k0
=
0
;
k0
<
D
;
k0
+=
2
)
{
if
(
FATTN_KQ_STRIDE
%
D
!=
0
&&
k_VKQ_0
+
k0
>=
ne11
)
{
break
;
}
half2
V_k
;
reinterpret_cast
<
half
&>
(
V_k
.
x
)
=
V_h
[(
k_VKQ_0
+
k0
+
0
)
*
stride_KV
+
tid
];
reinterpret_cast
<
half
&>
(
V_k
.
y
)
=
V_h
[(
k_VKQ_0
+
k0
+
1
)
*
stride_KV
+
tid
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
VKQ
[
j
]
+=
V_k
*
KQ2
[
j
*
(
D
/
2
)
+
k0
/
2
];
}
}
__syncthreads
();
}
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
kqsum
[
j
]
=
warp_reduce_sum
(
kqsum
[
j
]);
if
(
threadIdx
.
x
==
0
)
{
kqsum_shared
[
j
][
threadIdx
.
y
]
=
kqsum
[
j
];
}
}
__syncthreads
();
#pragma unroll
for
(
int
j_VKQ
=
0
;
j_VKQ
<
ncols
;
++
j_VKQ
)
{
if
(
ncols
>
2
&&
ic0
+
j_VKQ
>=
ne01
)
{
break
;
}
kqsum
[
j_VKQ
]
=
kqsum_shared
[
j_VKQ
][
threadIdx
.
x
];
kqsum
[
j_VKQ
]
=
warp_reduce_sum
(
kqsum
[
j_VKQ
]);
half
dst_val
=
(
__low2half
(
VKQ
[
j_VKQ
])
+
__high2half
(
VKQ
[
j_VKQ
]));
if
(
parallel_blocks
==
1
)
{
dst_val
/=
kqsum
[
j_VKQ
];
}
const
int
j_dst
=
(
ic0
+
j_VKQ
)
*
parallel_blocks
+
ip
;
dst
[
j_dst
*
D
*
gridDim
.
y
+
D
*
blockIdx
.
y
+
tid
]
=
dst_val
;
}
if
(
parallel_blocks
!=
1
&&
tid
<
ncols
&&
(
ncols
<=
2
||
ic0
+
tid
<
ne01
))
{
dst_meta
[(
ic0
+
tid
)
*
gridDim
.
y
*
parallel_blocks
+
blockIdx
.
y
*
parallel_blocks
+
ip
]
=
make_float2
(
kqmax
[
tid
],
kqsum
[
tid
]);
}
#else
NO_DEVICE_CODE
;
#endif // FP16_AVAILABLE
}
void
ggml_cuda_flash_attn_ext_vec_f16
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
ggml_tensor
*
KQV
=
dst
;
ggml_tensor
*
Q
=
dst
->
src
[
0
];
const
int32_t
precision
=
KQV
->
op_params
[
2
];
GGML_ASSERT
(
precision
==
GGML_PREC_DEFAULT
);
constexpr
int
cols_per_block
=
1
;
constexpr
int
parallel_blocks
=
4
;
switch
(
Q
->
ne
[
0
])
{
case
64
:
{
constexpr
int
D
=
64
;
constexpr
int
nwarps
=
D
/
WARP_SIZE
;
fattn_kernel_t
fattn_kernel
=
flash_attn_vec_ext_f16
<
D
,
cols_per_block
,
parallel_blocks
>
;
launch_fattn
<
D
,
parallel_blocks
>
(
ctx
,
dst
,
fattn_kernel
,
nwarps
,
cols_per_block
);
}
break
;
case
128
:
{
constexpr
int
D
=
128
;
constexpr
int
nwarps
=
D
/
WARP_SIZE
;
fattn_kernel_t
fattn_kernel
=
flash_attn_vec_ext_f16
<
D
,
cols_per_block
,
parallel_blocks
>
;
launch_fattn
<
D
,
parallel_blocks
>
(
ctx
,
dst
,
fattn_kernel
,
nwarps
,
cols_per_block
);
}
break
;
case
256
:
{
constexpr
int
D
=
256
;
constexpr
int
nwarps
=
D
/
WARP_SIZE
;
fattn_kernel_t
fattn_kernel
=
flash_attn_vec_ext_f16
<
D
,
cols_per_block
,
parallel_blocks
>
;
launch_fattn
<
D
,
parallel_blocks
>
(
ctx
,
dst
,
fattn_kernel
,
nwarps
,
cols_per_block
);
}
break
;
default:
GGML_ASSERT
(
false
);
break
;
}
}
template
<
int
cols_per_block
,
int
parallel_blocks
>
void
launch_fattn_vec_f16_64_128
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
ggml_tensor
*
Q
=
dst
->
src
[
0
];
switch
(
Q
->
ne
[
0
])
{
case
64
:
{
constexpr
int
D
=
64
;
constexpr
int
nwarps
=
D
/
WARP_SIZE
;
fattn_kernel_t
fattn_kernel
=
flash_attn_vec_ext_f16
<
D
,
cols_per_block
,
parallel_blocks
>
;
launch_fattn
<
D
,
parallel_blocks
>
(
ctx
,
dst
,
fattn_kernel
,
nwarps
,
cols_per_block
);
}
break
;
case
128
:
{
constexpr
int
D
=
128
;
constexpr
int
nwarps
=
D
/
WARP_SIZE
;
fattn_kernel_t
fattn_kernel
=
flash_attn_vec_ext_f16
<
D
,
cols_per_block
,
parallel_blocks
>
;
launch_fattn
<
D
,
parallel_blocks
>
(
ctx
,
dst
,
fattn_kernel
,
nwarps
,
cols_per_block
);
}
break
;
default:
{
GGML_ASSERT
(
false
&&
"FlashAttention without tensor cores only supports head sizes 64 and 128."
);
}
break
;
}
}
void
ggml_cuda_flash_attn_ext_vec_f16_no_mma
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
ggml_tensor
*
KQV
=
dst
;
const
ggml_tensor
*
Q
=
dst
->
src
[
0
];
const
int32_t
precision
=
KQV
->
op_params
[
2
];
GGML_ASSERT
(
precision
==
GGML_PREC_DEFAULT
);
if
(
Q
->
ne
[
1
]
==
1
)
{
ggml_cuda_flash_attn_ext_vec_f16
(
ctx
,
dst
);
return
;
}
if
(
Q
->
ne
[
1
]
==
2
)
{
constexpr
int
cols_per_block
=
2
;
constexpr
int
parallel_blocks
=
4
;
launch_fattn_vec_f16_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
return
;
}
if
(
Q
->
ne
[
1
]
<=
4
)
{
constexpr
int
cols_per_block
=
4
;
constexpr
int
parallel_blocks
=
4
;
launch_fattn_vec_f16_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
return
;
}
if
(
Q
->
ne
[
1
]
<=
8
)
{
constexpr
int
cols_per_block
=
8
;
constexpr
int
parallel_blocks
=
4
;
launch_fattn_vec_f16_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
return
;
}
constexpr
int
cols_per_block
=
8
;
constexpr
int
parallel_blocks
=
1
;
launch_fattn_vec_f16_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
}
llm/llama.cpp/ggml-cuda/fattn-vec-f16.cuh deleted 100644 → 0
#include "common.cuh"
void
ggml_cuda_flash_attn_ext_vec_f16
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
);
void
ggml_cuda_flash_attn_ext_vec_f16_no_mma
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
);
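The f16 kernels do nearly all arithmetic on packed half2 pairs (Q_h2, K_h2, sum2, VKQ), so each fused multiply-add covers two elements, and the pair is only folded back to a scalar at the end with __low2half/__high2half. A small device-side sketch of that packed dot-product idiom (an illustrative helper, not code taken from the kernels):

#include <cuda_fp16.h>

// Packed-f16 dot product: accumulate two lanes per __hfma2, then fold the lanes,
// mirroring the sum2 -> __low2half + __high2half pattern in the vector kernels.
// (Device-side helper; requires a GPU with native fp16 arithmetic.)
static __device__ float dot_f16x2(const half2 * a, const half2 * b, const int n_half2) {
    half2 acc = __float2half2_rn(0.0f);
    for (int i = 0; i < n_half2; ++i) {
        acc = __hfma2(a[i], b[i], acc); // acc.x += a[i].x*b[i].x; acc.y += a[i].y*b[i].y
    }
    return __half2float(__low2half(acc)) + __half2float(__high2half(acc));
}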
llm/llama.cpp/ggml-cuda/fattn-vec-f32.cu deleted 100644 → 0
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-vec-f32.cuh"
template
<
int
D
,
int
ncols
,
int
parallel_blocks
>
// D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__
(
D
,
1
)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static
__global__
__launch_bounds__
(
1024
)
void
flash_attn_vec_ext_f32
(
const
char
*
__restrict__
Q
,
const
char
*
__restrict__
K
,
const
char
*
__restrict__
V
,
const
char
*
__restrict__
mask
,
float
*
__restrict__
dst
,
float2
*
__restrict__
dst_meta
,
const
float
scale
,
const
float
max_bias
,
const
float
m0
,
const
float
m1
,
const
uint32_t
n_head_log2
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
ne03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
ne13
,
const
int
ne31
,
const
int
nb31
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
const
int
ne0
,
const
int
ne1
,
const
int
ne2
,
const
int
ne3
)
{
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const
int
ic0
=
(
blockIdx
.
x
/
parallel_blocks
)
*
ncols
;
// Index of the Q/QKV column to work on.
const
int
ip
=
blockIdx
.
x
%
parallel_blocks
;
// Index in group of blocks running for the same column in parallel.
const
int
gqa_ratio
=
ne02
/
ne12
;
// With grouped query attention there are > 1 Q matrices per K, V matrix.
const
float2
*
Q_f2
=
(
const
float2
*
)
(
Q
+
nb02
*
blockIdx
.
y
+
nb01
*
ic0
);
const
half2
*
K_h2
=
(
const
half2
*
)
(
K
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
const
half
*
V_h
=
(
const
half
*
)
(
V
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
// K and V have same shape
const
half
*
maskh
=
(
const
half
*
)
mask
+
ne11
*
ic0
;
const
int
stride_KV
=
nb11
/
sizeof
(
half
);
const
int
stride_KV2
=
nb11
/
sizeof
(
half2
);
const
float
slope
=
get_alibi_slope
(
max_bias
,
blockIdx
.
y
,
n_head_log2
,
m0
,
m1
);
static_assert
(
D
%
(
2
*
WARP_SIZE
)
==
0
,
"D not divisible by 2*WARP_SIZE == 64."
);
constexpr
int
nwarps
=
D
/
WARP_SIZE
;
const
int
tid
=
WARP_SIZE
*
threadIdx
.
y
+
threadIdx
.
x
;
__builtin_assume
(
tid
<
D
);
__shared__
float
KQ
[
ncols
*
D
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
KQ
[
j
*
D
+
tid
]
=
-
FLT_MAX
/
2.0
f
;
}
float
kqmax
[
ncols
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
kqmax
[
j
]
=
-
FLT_MAX
/
2.0
f
;
}
float
kqsum
[
ncols
]
=
{
0.0
f
};
__shared__
float
kqmax_shared
[
ncols
][
WARP_SIZE
];
__shared__
float
kqsum_shared
[
ncols
][
WARP_SIZE
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
if
(
threadIdx
.
y
==
0
)
{
kqmax_shared
[
j
][
threadIdx
.
x
]
=
-
FLT_MAX
/
2.0
f
;
kqsum_shared
[
j
][
threadIdx
.
x
]
=
0.0
f
;
}
}
__syncthreads
();
// Convert Q to half2 and store in registers:
float2
Q_h2
[
ncols
][
D
/
(
2
*
WARP_SIZE
)];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
Q_h2
[
j
][
i0
/
WARP_SIZE
]
=
ncols
<=
2
||
ic0
+
j
?
Q_f2
[
j
*
(
nb01
/
sizeof
(
float2
))
+
i
]
:
make_float2
(
0.0
f
,
0.0
f
);
Q_h2
[
j
][
i0
/
WARP_SIZE
].
x
*=
scale
;
Q_h2
[
j
][
i0
/
WARP_SIZE
].
y
*=
scale
;
}
}
float
VKQ
[
ncols
]
=
{
0.0
f
};
const
int
k_start
=
parallel_blocks
==
1
?
0
:
ip
*
D
;
for
(
int
k_VKQ_0
=
k_start
;
k_VKQ_0
<
ne11
;
k_VKQ_0
+=
parallel_blocks
*
D
)
{
// Calculate KQ tile and keep track of new maximum KQ values:
float
kqmax_new_arr
[
ncols
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
kqmax_new_arr
[
j
]
=
kqmax
[
j
];
}
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
D
;
i_KQ_0
+=
nwarps
)
{
const
int
i_KQ
=
i_KQ_0
+
threadIdx
.
y
;
if
((
i_KQ_0
+
nwarps
>
D
&&
i_KQ
>=
D
)
||
(
FATTN_KQ_STRIDE
%
D
!=
0
&&
k_VKQ_0
+
i_KQ
>=
ne11
))
{
break
;
}
float
sum
[
ncols
]
=
{
0.0
f
};
#pragma unroll
for
(
int
k_KQ_0
=
0
;
k_KQ_0
<
D
/
2
;
k_KQ_0
+=
WARP_SIZE
)
{
const
int
k_KQ
=
k_KQ_0
+
threadIdx
.
x
;
const
half2
K_ik
=
K_h2
[(
k_VKQ_0
+
i_KQ
)
*
stride_KV2
+
k_KQ
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
sum
[
j
]
+=
__low2float
(
K_ik
)
*
Q_h2
[
j
][
k_KQ_0
/
WARP_SIZE
].
x
;
sum
[
j
]
+=
__high2float
(
K_ik
)
*
Q_h2
[
j
][
k_KQ_0
/
WARP_SIZE
].
y
;
}
}
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
sum
[
j
]
=
warp_reduce_sum
(
sum
[
j
]);
sum
[
j
]
+=
mask
?
slope
*
__half2float
(
maskh
[
j
*
ne11
+
k_VKQ_0
+
i_KQ
])
:
0.0
f
;
kqmax_new_arr
[
j
]
=
fmaxf
(
kqmax_new_arr
[
j
],
sum
[
j
]);
if
(
threadIdx
.
x
==
0
)
{
KQ
[
j
*
D
+
i_KQ
]
=
sum
[
j
];
}
}
}
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
float
kqmax_new_j
=
kqmax_new_arr
[
j
];
kqmax_new_j
=
warp_reduce_max
(
kqmax_new_j
);
if
(
threadIdx
.
x
==
0
)
{
kqmax_shared
[
j
][
threadIdx
.
y
]
=
kqmax_new_j
;
}
}
__syncthreads
();
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
float
kqmax_new_j
=
kqmax_shared
[
j
][
threadIdx
.
x
];
kqmax_new_j
=
warp_reduce_max
(
kqmax_new_j
);
const
float
KQ_max_scale
=
expf
(
kqmax
[
j
]
-
kqmax_new_j
);
kqmax
[
j
]
=
kqmax_new_j
;
const
float
val
=
expf
(
KQ
[
j
*
D
+
tid
]
-
kqmax
[
j
]);
kqsum
[
j
]
=
kqsum
[
j
]
*
KQ_max_scale
+
val
;
KQ
[
j
*
D
+
tid
]
=
val
;
VKQ
[
j
]
*=
KQ_max_scale
;
}
__syncthreads
();
#pragma unroll
for
(
int
k
=
0
;
k
<
D
;
++
k
)
{
if
(
FATTN_KQ_STRIDE
%
D
!=
0
&&
k_VKQ_0
+
k
>=
ne11
)
{
break
;
}
const
float
V_ki
=
__half2float
(
V_h
[(
k_VKQ_0
+
k
)
*
stride_KV
+
tid
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
VKQ
[
j
]
+=
V_ki
*
KQ
[
j
*
D
+
k
];
}
}
__syncthreads
();
}
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
;
++
j
)
{
kqsum
[
j
]
=
warp_reduce_sum
(
kqsum
[
j
]);
if
(
threadIdx
.
x
==
0
)
{
kqsum_shared
[
j
][
threadIdx
.
y
]
=
kqsum
[
j
];
}
}
__syncthreads
();
#pragma unroll
for
(
int
j_VKQ
=
0
;
j_VKQ
<
ncols
;
++
j_VKQ
)
{
if
(
ncols
>
2
&&
ic0
+
j_VKQ
>=
ne01
)
{
break
;
}
kqsum
[
j_VKQ
]
=
kqsum_shared
[
j_VKQ
][
threadIdx
.
x
];
kqsum
[
j_VKQ
]
=
warp_reduce_sum
(
kqsum
[
j_VKQ
]);
float
dst_val
=
VKQ
[
j_VKQ
];
if
(
parallel_blocks
==
1
)
{
dst_val
/=
kqsum
[
j_VKQ
];
}
const
int
j_dst
=
(
ic0
+
j_VKQ
)
*
parallel_blocks
+
ip
;
dst
[
j_dst
*
D
*
gridDim
.
y
+
D
*
blockIdx
.
y
+
tid
]
=
dst_val
;
}
if
(
parallel_blocks
!=
1
&&
tid
<
ncols
&&
(
ncols
<=
2
||
ic0
+
tid
<
ne01
))
{
dst_meta
[(
ic0
+
tid
)
*
gridDim
.
y
*
parallel_blocks
+
blockIdx
.
y
*
parallel_blocks
+
ip
]
=
make_float2
(
kqmax
[
tid
],
kqsum
[
tid
]);
}
}
template
<
int
cols_per_block
,
int
parallel_blocks
>
void
launch_fattn_vec_f32_64_128
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
ggml_tensor
*
Q
=
dst
->
src
[
0
];
switch
(
Q
->
ne
[
0
])
{
case
64
:
{
constexpr
int
D
=
64
;
constexpr
int
nwarps
=
D
/
WARP_SIZE
;
fattn_kernel_t
fattn_kernel
=
flash_attn_vec_ext_f32
<
D
,
cols_per_block
,
parallel_blocks
>
;
launch_fattn
<
D
,
parallel_blocks
>
(
ctx
,
dst
,
fattn_kernel
,
nwarps
,
cols_per_block
);
}
break
;
case
128
:
{
constexpr
int
D
=
128
;
constexpr
int
nwarps
=
D
/
WARP_SIZE
;
fattn_kernel_t
fattn_kernel
=
flash_attn_vec_ext_f32
<
D
,
cols_per_block
,
parallel_blocks
>
;
launch_fattn
<
D
,
parallel_blocks
>
(
ctx
,
dst
,
fattn_kernel
,
nwarps
,
cols_per_block
);
}
break
;
default:
{
GGML_ASSERT
(
false
&&
"FlashAttention without tensor cores only supports head sizes 64 and 128."
);
}
break
;
}
}
void
ggml_cuda_flash_attn_ext_vec_f32
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
)
{
const
ggml_tensor
*
Q
=
dst
->
src
[
0
];
if
(
Q
->
ne
[
1
]
==
1
)
{
constexpr
int
cols_per_block
=
1
;
constexpr
int
parallel_blocks
=
4
;
launch_fattn_vec_f32_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
return
;
}
if
(
Q
->
ne
[
1
]
==
2
)
{
constexpr
int
cols_per_block
=
2
;
constexpr
int
parallel_blocks
=
4
;
launch_fattn_vec_f32_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
return
;
}
if
(
Q
->
ne
[
1
]
<=
4
)
{
constexpr
int
cols_per_block
=
4
;
constexpr
int
parallel_blocks
=
4
;
launch_fattn_vec_f32_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
return
;
}
if
(
Q
->
ne
[
1
]
<=
8
)
{
constexpr
int
cols_per_block
=
8
;
constexpr
int
parallel_blocks
=
4
;
launch_fattn_vec_f32_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
return
;
}
constexpr
int
cols_per_block
=
8
;
constexpr
int
parallel_blocks
=
1
;
launch_fattn_vec_f32_64_128
<
cols_per_block
,
parallel_blocks
>
(
ctx
,
dst
);
}
llm/llama.cpp/ggml-cuda/fattn-vec-f32.cuh deleted 100644 → 0
#include "common.cuh"
void
ggml_cuda_flash_attn_ext_vec_f32
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
);
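Every kernel in these files computes gqa_ratio = ne02/ne12 and indexes K and V with blockIdx.y / gqa_ratio, i.e. with grouped-query attention several query heads share one K/V head. The mapping in isolation, as a sketch (hypothetical helper name):

#include <cassert>

// Grouped-query attention head mapping, as done via gqa_ratio = ne02/ne12 above:
// consecutive groups of query heads read the same K/V head.
static int kv_head_for_query_head(const int q_head, const int n_q_heads, const int n_kv_heads) {
    assert(n_q_heads % n_kv_heads == 0);
    const int gqa_ratio = n_q_heads / n_kv_heads; // > 1 when K/V heads are shared
    return q_head / gqa_ratio;
}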
llm/llama.cpp/ggml-cuda/fattn.cu deleted 100644 → 0
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f16.cuh"
#include "fattn-tile-f32.cuh"
#include "fattn-vec-f16.cuh"
#include "fattn-vec-f32.cuh"
#include "fattn.cuh"
#include <cstdint>
#if FP16_MMA_AVAILABLE
#include <mma.h>
#endif
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template
<
int
D
,
int
ncols
,
int
nwarps
,
int
VKQ_stride
,
int
parallel_blocks
,
typename
KQ_acc_t
>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__
(
nwarps
*
WARP_SIZE
,
1
)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static
__global__
__launch_bounds__
(
1024
)
void
flash_attn_ext_f16
(
const
char
*
__restrict__
Q
,
const
char
*
__restrict__
K
,
const
char
*
__restrict__
V
,
const
char
*
__restrict__
mask
,
float
*
__restrict__
dst
,
float2
*
__restrict__
dst_meta
,
const
float
scale
,
const
float
max_bias
,
const
float
m0
,
const
float
m1
,
const
uint32_t
n_head_log2
,
const
int
ne00
,
const
int
ne01
,
const
int
ne02
,
const
int
ne03
,
const
int
ne10
,
const
int
ne11
,
const
int
ne12
,
const
int
ne13
,
const
int
ne31
,
const
int
nb31
,
const
int
nb01
,
const
int
nb02
,
const
int
nb03
,
const
int
nb11
,
const
int
nb12
,
const
int
nb13
,
const
int
ne0
,
const
int
ne1
,
const
int
ne2
,
const
int
ne3
)
{
#if FP16_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const
int
ic0
=
ncols
*
(
blockIdx
.
x
/
parallel_blocks
);
// Index of the first Q/QKV column to work on.
const
int
ip
=
blockIdx
.
x
%
parallel_blocks
;
// Index in group of blocks running for the same column in parallel.
static_assert
(
D
<=
FATTN_KQ_STRIDE
,
"D must be <= FATTN_KQ_STRIDE."
);
static_assert
(
ncols
==
8
||
ncols
%
16
==
0
,
"ncols must be 8 or a multiple of 16."
);
constexpr
int
frag_m
=
ncols
==
8
?
32
:
16
;
constexpr
int
frag_n
=
ncols
==
8
?
8
:
16
;
static_assert
(
D
%
frag_m
==
0
,
"If ncols == 8 then D % frag_m must be 0."
);
typedef
nvcuda
::
wmma
::
fragment
<
nvcuda
::
wmma
::
matrix_a
,
frag_m
,
frag_n
,
16
,
half
,
nvcuda
::
wmma
::
row_major
>
frag_a_K
;
typedef
nvcuda
::
wmma
::
fragment
<
nvcuda
::
wmma
::
matrix_a
,
frag_m
,
frag_n
,
16
,
half
,
nvcuda
::
wmma
::
col_major
>
frag_a_V
;
typedef
nvcuda
::
wmma
::
fragment
<
nvcuda
::
wmma
::
matrix_b
,
frag_m
,
frag_n
,
16
,
half
,
nvcuda
::
wmma
::
col_major
>
frag_b
;
typedef
nvcuda
::
wmma
::
fragment
<
nvcuda
::
wmma
::
accumulator
,
frag_m
,
frag_n
,
16
,
KQ_acc_t
>
frag_c_KQ
;
typedef
nvcuda
::
wmma
::
fragment
<
nvcuda
::
wmma
::
accumulator
,
frag_m
,
frag_n
,
16
,
half
>
frag_c_VKQ
;
constexpr
int
KQ_stride_tc
=
nwarps
*
frag_m
;
// Number of KQ rows calculated in parallel.
constexpr
int
VKQ_ratio
=
KQ_stride_tc
/
VKQ_stride
;
// Number of parallel VKQ accumulators needed to keep all warps busy.
static_assert
(
VKQ_ratio
<=
nwarps
,
"VKQ_ratio must be <= nwarps."
);
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
constexpr
int
D_padded
=
D
+
8
;
constexpr
int
kqs_padded
=
FATTN_KQ_STRIDE
+
8
;
constexpr
int
kqar
=
sizeof
(
KQ_acc_t
)
/
sizeof
(
half
);
const
int
gqa_ratio
=
ne02
/
ne12
;
// With grouped query attention there are > 1 Q matrices per K, V matrix.
const
float
*
Q_f
=
(
const
float
*
)
(
Q
+
nb02
*
blockIdx
.
y
+
nb01
*
ic0
);
const
half
*
K_h
=
(
const
half
*
)
(
K
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
const
half
*
V_h
=
(
const
half
*
)
(
V
+
nb12
*
(
blockIdx
.
y
/
gqa_ratio
));
// K and V have same shape
const
half
*
maskh
=
(
const
half
*
)
mask
+
(
nb31
/
sizeof
(
half
))
*
ic0
;
const
half2
*
mask2
=
(
const
half2
*
)
mask
+
(
nb31
/
sizeof
(
half
))
*
(
ic0
/
2
);
const
int
stride_Q
=
nb01
/
sizeof
(
float
);
const
int
stride_KV
=
nb11
/
sizeof
(
half
);
const
float
slopef
=
get_alibi_slope
(
max_bias
,
blockIdx
.
y
,
n_head_log2
,
m0
,
m1
);
const
half
slopeh
=
__float2half
(
slopef
);
const
half2
slope2
=
make_half2
(
slopef
,
slopef
);
frag_b
Q_b
[
D
/
16
][
ncols
/
frag_n
];
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
constexpr
int
mem_KQ
=
ncols
*
kqs_padded
*
kqar
;
constexpr
int
mem_VKQ_parts
=
VKQ_ratio
*
ncols
*
D_padded
;
__shared__
half
KQ
[
mem_KQ
>=
mem_VKQ_parts
?
mem_KQ
:
mem_VKQ_parts
];
float
*
KQ_f
=
(
float
*
)
KQ
;
half2
*
KQ2
=
(
half2
*
)
KQ
;
float
KQ_rowsum_f
[
ncols
/
nwarps
]
=
{
0.0
f
};
float
KQ_max_f
[
ncols
/
nwarps
];
float
KQ_max_scale_f
[
ncols
/
nwarps
]
=
{
0.0
f
};
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
/
nwarps
;
++
j
)
{
KQ_max_f
[
j
]
=
-
FLT_MAX
/
2.0
f
;
}
half2
KQ_rowsum_h2
[
ncols
/
nwarps
]
=
{{
0.0
f
,
0.0
f
}};
half2
KQ_max_h2
[
ncols
/
nwarps
];
half2
KQ_max_scale_h2
[
ncols
/
nwarps
]
=
{{
0.0
f
,
0.0
f
}};
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
/
nwarps
;
++
j
)
{
KQ_max_h2
[
j
]
=
make_half2
(
-
HALF_MAX_HALF
,
-
HALF_MAX_HALF
);
}
__shared__
half
VKQ
[
ncols
*
D_padded
];
// Accumulator for final VKQ slice.
half2
*
VKQ2
=
(
half2
*
)
VKQ
;
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
const
int
j
=
j0
+
threadIdx
.
y
;
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
/
2
;
i0
+=
WARP_SIZE
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
if
(
i0
+
WARP_SIZE
>
D
/
2
&&
i
>=
D
/
2
)
{
break
;
}
VKQ2
[
j
*
(
D_padded
/
2
)
+
i
]
=
make_half2
(
0.0
f
,
0.0
f
);
}
}
// Convert Q to half and apply scale, temporarily store in KQ:
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
const
int
j
=
j0
+
threadIdx
.
y
;
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
;
i0
+=
WARP_SIZE
)
{
const
int
i
=
i0
+
threadIdx
.
x
;
if
(
i0
+
WARP_SIZE
>
D
&&
i
>=
D
)
{
break
;
}
KQ
[
j
*
D_padded
+
i
]
=
ic0
+
j
<
ne01
?
Q_f
[
j
*
stride_Q
+
i
]
*
scale
:
0.0
f
;
}
}
__syncthreads
();
// Load Q into tensor core fragments/registers since it will be used frequently:
#pragma unroll
for
(
int
i0
=
0
;
i0
<
D
;
i0
+=
16
)
{
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
frag_n
)
{
nvcuda
::
wmma
::
load_matrix_sync
(
Q_b
[
i0
/
16
][
j0
/
frag_n
],
KQ
+
j0
*
D_padded
+
i0
,
D_padded
);
}
}
__syncthreads
();
// Iterate over ne11 == previous tokens:
for
(
int
k_VKQ_0
=
ip
*
FATTN_KQ_STRIDE
;
k_VKQ_0
<
ne11
;
k_VKQ_0
+=
parallel_blocks
*
FATTN_KQ_STRIDE
)
{
// Calculate tile of KQ:
#pragma unroll
for
(
int
i_KQ_0
=
0
;
i_KQ_0
<
FATTN_KQ_STRIDE
;
i_KQ_0
+=
KQ_stride_tc
)
{
frag_c_KQ
KQ_c
[
ncols
/
frag_n
];
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
/
frag_n
;
++
j
)
{
nvcuda
::
wmma
::
fill_fragment
(
KQ_c
[
j
],
0.0
f
);
}
#pragma unroll
for
(
int
k_KQ_0
=
0
;
k_KQ_0
<
D
;
k_KQ_0
+=
16
)
{
frag_a_K
K_a
;
nvcuda
::
wmma
::
load_matrix_sync
(
K_a
,
K_h
+
(
k_VKQ_0
+
i_KQ_0
+
frag_m
*
threadIdx
.
y
)
*
stride_KV
+
k_KQ_0
,
stride_KV
);
#pragma unroll
for
(
int
j
=
0
;
j
<
ncols
/
frag_n
;
++
j
)
{
nvcuda
::
wmma
::
mma_sync
(
KQ_c
[
j
],
K_a
,
Q_b
[
k_KQ_0
/
16
][
j
],
KQ_c
[
j
]);
}
}
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
frag_n
)
{
nvcuda
::
wmma
::
store_matrix_sync
((
KQ_acc_t
*
)
KQ
+
j0
*
kqs_padded
+
i_KQ_0
+
frag_m
*
threadIdx
.
y
,
KQ_c
[
j0
/
frag_n
],
kqs_padded
,
nvcuda
::
wmma
::
mem_col_major
);
}
}
__syncthreads
();
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
#pragma unroll
for
(
int
j0
=
0
;
j0
<
ncols
;
j0
+=
nwarps
)
{
const
int
j
=
j0
+
threadIdx
.
y
;
if
(
std
::
is_same
<
KQ_acc_t
,
float
>::
value
)
{
float
KQ_f_tmp
[
FATTN_KQ_STRIDE
/
WARP_SIZE
];
#pragma unroll
for
(
int
k0
=
0
;
k0
<
FATTN_KQ_STRIDE
;
k0
+=
WARP_SIZE
)
{
const
int
k
=
k0
+
threadIdx
.
x
;
KQ_f_tmp
[
k0
/
WARP_SIZE
]
=
KQ_f
[
j
*
kqs_padded
+
k
];
}
float
KQ_max_new
=
KQ_max_f
[
j0
/
nwarps
];
#pragma unroll
for
(
int
k0
=
0
;
k0
<
FATTN_KQ_STRIDE
;
k0
+=
WARP_SIZE
)
{
const
int
k
=
k0
+
threadIdx
.
x
;
KQ_f_tmp
[
k0
/
WARP_SIZE
]
+=
mask
?
__half2float
(
slopeh
*
maskh
[
j
*
(
nb31
/
sizeof
(
half
))
+
k_VKQ_0
+
k
])
:
0.0
f
;
KQ_max_new
=
max
(
KQ_max_new
,
KQ_f_tmp
[
k0
/
WARP_SIZE
]);
}
KQ_max_new
=
warp_reduce_max
(
KQ_max_new
);
const
float
diff
=
KQ_max_f
[
j0
/
nwarps
]
-
KQ_max_new
;
KQ_max_scale_f
[
j0
/
nwarps
]
=
expf
(
diff
);
if
(
diff
<=
SOFTMAX_FTZ_THRESHOLD
)
{
KQ_max_scale_f
[
j0
/
nwarps
]
=
0.0
f
;
}
KQ_max_f
[
j0
/
nwarps
]
=
KQ_max_new
;
float
KQ_rowsum_add
=
0.0
f
;
#pragma unroll
for
(
int
k0
=
0
;
k0
<
FATTN_KQ_STRIDE
;
k0
+=
WARP_SIZE
)
{
const
int
k
=
k0
+
threadIdx
.
x
;
const
float
diff
=
KQ_f_tmp
[
k0
/
WARP_SIZE
]
-
KQ_max_f
[
j0
/
nwarps
];
KQ_f_tmp
[
k0
/
WARP_SIZE
]
=
expf
(
diff
);
if
(
diff
<=
SOFTMAX_FTZ_THRESHOLD
)
{
KQ_f_tmp
[
k0
/
WARP_SIZE
]
=
0.0
f
;
}
KQ_rowsum_add
+=
KQ_f_tmp
[
k0
/
WARP_SIZE
];
KQ
[
j
*
(
kqar
*
kqs_padded
)
+
k
]
=
KQ_f_tmp
[
k0
/
WARP_SIZE
];
}
KQ_rowsum_add
=
warp_reduce_sum
(
KQ_rowsum_add
);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_f
[
j0
/
nwarps
]
=
KQ_max_scale_f
[
j0
/
nwarps
]
*
KQ_rowsum_f
[
j0
/
nwarps
]
+
KQ_rowsum_add
;
}
else
{
half2
KQ2_tmp
[
FATTN_KQ_STRIDE
/
(
2
*
WARP_SIZE
)];
#pragma unroll
for
(
int
k0
=
0
;
k0
<
FATTN_KQ_STRIDE
/
2
;
k0
+=
WARP_SIZE
)
{
const
int
k
=
k0
+
threadIdx
.
x
;
KQ2_tmp
[
k0
/
WARP_SIZE
]
=
KQ2
[
j
*
(
kqs_padded
/
2
)
+
k
];
}
                half2 KQ_max_new = KQ_max_h2[j0/nwarps];
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                    const int k = k0 + threadIdx.x;

                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
                }
                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
                const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
                KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
                const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
                *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
                KQ_max_h2[j0/nwarps] = KQ_max_new;

                half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
#pragma unroll
                for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                    const int k = k0 + threadIdx.x;

                    const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
                    KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
                    const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
                    *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
                    KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
                    KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
                }
                KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);

                // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
                KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
            }
        }

        __syncthreads();

        frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#pragma unroll
            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                nvcuda::wmma::load_matrix_sync(KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], KQ + j0*(kqar*kqs_padded) + k, kqar*kqs_padded);
            }
        }

        frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
#pragma unroll
        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
            for (int j = 0; j < ncols/frag_n; ++j) {
                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
            }

#pragma unroll
            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                const int k = k0 + (threadIdx.y % VKQ_ratio)*16;

                frag_a_V v_a;
                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
                for (int j = 0; j < ncols/frag_n; ++j) {
                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                }
            }
        }

        __syncthreads();

        const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
#pragma unroll
            for (int j0 = 0; j0 < ncols; j0 += frag_n) {
                nvcuda::wmma::store_matrix_sync(KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], D_padded, nvcuda::wmma::mem_col_major);
            }
        }

        __syncthreads();

#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            half2 VKQ_scale;
            if (std::is_same<KQ_acc_t, float>::value) {
                VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
            } else {
                VKQ_scale = KQ_max_scale_h2[j0/nwarps];
            }

#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;
                if (i0 + WARP_SIZE > D/2 && i >= D/2) {
                    break;
                }

                half2 VKQ_add = make_half2(0.0f, 0.0f);
#pragma unroll
                for (int l = 0; l < VKQ_ratio; ++l) {
                    VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
                }
                VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
            }
        }

        __syncthreads();
    }

#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        const int j_VKQ = j0 + threadIdx.y;
        if (ic0 + j_VKQ >= ne01) {
            return;
        }
        const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;

        float KQ_rowsum_j;
        if (std::is_same<KQ_acc_t, float>::value) {
            KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
        } else {
            KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
        }

#pragma unroll
        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;
            if (i0 + WARP_SIZE > D && i >= D) {
                break;
            }
            float dst_val = VKQ[j_VKQ*D_padded + i];
            if (parallel_blocks == 1) {
                dst_val /= KQ_rowsum_j;
            }
            dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
        }

        if (parallel_blocks == 1 || threadIdx.x != 0) {
            continue;
        }

        float2 dst_meta_val;
        if (std::is_same<KQ_acc_t, float>::value) {
            dst_meta_val.x = KQ_max_f[j0/nwarps];
        } else {
            dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
        }
        dst_meta_val.y = KQ_rowsum_j;
        dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
    }
#else
    NO_DEVICE_CODE;
#endif // FP16_MMA_AVAILABLE
}
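The rescaling performed per KQ row above is the standard online-softmax update: whenever a new KV tile raises the running row maximum, the previously accumulated row sum (and the VKQ accumulator) is scaled by exp(old_max - new_max) before the new tile's contributions are added. The following is a minimal scalar sketch of that update, independent of the half2/WMMA machinery; all names are invented for illustration and are not part of the file.

#include <cmath>
#include <cstdio>
#include <vector>

// Editorial sketch: scalar version of the per-row online-softmax update.
// Only the arithmetic mirrors the kernel; the names are illustrative.
struct RowState {
    float kq_max    = -INFINITY; // running maximum of the row's K*Q scores
    float kq_rowsum = 0.0f;      // running sum of exp(score - kq_max)
};

// Fold one KV tile's scores into the running state; the old sum is rescaled by
// exp(old_max - new_max), as KQ_max_scale_h2 rescales KQ_rowsum_h2 above.
static void fold_tile(RowState & st, const std::vector<float> & tile_scores) {
    float kq_max_new = st.kq_max;
    for (float s : tile_scores) {
        kq_max_new = std::fmax(kq_max_new, s);
    }
    const float kq_max_scale = std::exp(st.kq_max - kq_max_new); // <= 1 (the kernel also flushes tiny scales to zero)

    float kq_rowsum_add = 0.0f;
    for (float s : tile_scores) {
        kq_rowsum_add += std::exp(s - kq_max_new);
    }
    st.kq_max    = kq_max_new;
    st.kq_rowsum = kq_max_scale*st.kq_rowsum + kq_rowsum_add;
}

int main() {
    RowState st;
    fold_tile(st, {1.0f, 2.0f, 3.0f});
    fold_tile(st, {5.0f, 0.5f}); // a larger max arrives -> the previous sum gets rescaled
    std::printf("max=%g rowsum=%g\n", st.kq_max, st.kq_rowsum);
    return 0;
}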
constexpr int get_max_power_of_2(int x) {
    return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
}

static_assert(get_max_power_of_2(1) == 1, "Test failed.");
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
// Number of VKQ rows calculated in parallel:
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
    return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
}

static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed.");
static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed.");
static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed.");
static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
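For head sizes whose ratio D/frag_m is odd, the stride collapses to a single fragment regardless of nwarps, which is what the three D = 80 asserts above encode. A small compile-time walk-through, assuming it sits in the same translation unit right after the two constexpr helpers above (illustrative only, not part of the original file):

// get_VKQ_stride(80, 4, 16):
//   D/frag_m              = 80/16 = 5
//   get_max_power_of_2(5) = 1      (5 is odd)
//   min(1, nwarps = 4)    = 1
//   stride                = 1*16   = 16
static_assert(get_max_power_of_2(80/16) == 1, "largest power-of-two divisor of 5 is 1");
static_assert(get_VKQ_stride(80, 4, 16) == get_max_power_of_2(80/16)*16, "stride is one fragment");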
template <int D, int cols_per_block, int nwarps, typename KQ_acc_t>
void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];

    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;

    const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;

    if (4*blocks_num_pb1 < 2*nsm) {
        constexpr int parallel_blocks = 4;
        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
        return;
    }
    if (2*blocks_num_pb1 < 2*nsm) {
        constexpr int parallel_blocks = 2;
        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
        return;
    }
    constexpr int parallel_blocks = 1;
    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
}
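launch_fattn_f16 chooses parallel_blocks purely from occupancy: when even four (or two) blocks per query tile would leave the GPU with fewer than roughly two blocks per SM, the KV sequence is split across that many parallel blocks (to be reduced afterwards); otherwise one block per tile suffices. Below is a standalone sketch of just that decision, with a hypothetical helper name that does not exist in the file.

// Hypothetical helper mirroring the branch order in launch_fattn_f16 above:
// prefer splitting the KV dimension 4 ways, then 2 ways, only when the grid
// would otherwise be too small to keep about 2 blocks resident per SM.
static int choose_parallel_blocks(int blocks_num_pb1, int nsm) {
    if (4*blocks_num_pb1 < 2*nsm) {
        return 4;
    }
    if (2*blocks_num_pb1 < 2*nsm) {
        return 2;
    }
    return 1;
}
// Example: blocks_num_pb1 = 12 tiles on a GPU with nsm = 40 SMs
//   4*12 = 48 < 80  -> split 4 ways;
// whereas blocks_num_pb1 = 50 gives 4*50 = 200 >= 80 and 2*50 = 100 >= 80 -> 1 block.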
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];

    ggml_cuda_set_device(ctx.device);
    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int32_t precision = KQV->op_params[2];

    // On AMD the tile kernels perform poorly, use the vec kernel instead:
    if (cc >= CC_OFFSET_AMD) {
        if (precision == GGML_PREC_DEFAULT) {
            ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }

    if (!fast_fp16_available(cc)) {
        if (Q->ne[1] <= 8) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
        }
        return;
    }

    if (!fp16_mma_available(cc)) {
        if (Q->ne[1] <= 8) {
            ggml_cuda_flash_attn_ext_vec_f16_no_mma(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
        }
        return;
    }

    if (precision != GGML_PREC_DEFAULT) {
        if (Q->ne[1] == 1 && (Q->ne[0] == 64 || Q->ne[0] == 128)) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
            return;
        }

        if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
            constexpr int cols_per_block = 16;
            constexpr int nwarps         =  4;
            switch (Q->ne[0]) {
                case 64:
                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
                    break;
                case 80:
                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
                    break;
                case 96:
                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
                    break;
                case 112:
                    launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
                    break;
                case 128:
                    launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
                    break;
                case 256:
                    launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
                    break;
                default:
                    GGML_ASSERT(false);
                    break;
            }
        } else {
            constexpr int cols_per_block = 32;
            constexpr int nwarps         =  4;
            switch (Q->ne[0]) {
                case 64:
                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
                    break;
                case 80:
                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
                    break;
                case 96:
                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
                    break;
                case 112:
                    launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
                    break;
                case 128:
                    launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
                    break;
                // case 256:
                //     launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
                //     break;
                default:
                    GGML_ASSERT(false);
                    break;
            }
        }
        return;
    }

    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
        ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        return;
    }

    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
        constexpr int cols_per_block = 8;
        constexpr int nwarps         = 4;
        switch (Q->ne[0]) {
            case 64:
                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
                break;
            case 96:
                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
                break;
            case 128:
                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
                break;
            case 256:
                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
                break;
            default:
                GGML_ASSERT(false);
                break;
        }
        return;
    }

    if (Q->ne[1] <= 32) {
        constexpr int cols_per_block = 16;
        constexpr int nwarps         =  4;
        switch (Q->ne[0]) {
            case 64:
                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
                break;
            case 80:
                launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
                break;
            case 96:
                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
                break;
            case 112:
                launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
                break;
            case 128:
                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
                break;
            case 256:
                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
                break;
            default:
                GGML_ASSERT(false);
                break;
        }
        return;
    }

    constexpr int cols_per_block = 32;
    constexpr int nwarps         =  4;
    switch (Q->ne[0]) {
        case 64:
            launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
            break;
        case 80:
            launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
            break;
        case 96:
            launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
            break;
        case 112:
            launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
            break;
        case 128:
            launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
            break;
        case 256:
            launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
            break;
        default:
            GGML_ASSERT(false);
            break;
    }
    return;
}
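Taken together, the dispatch above reduces to a short decision list evaluated top to bottom. The sketch below is a hypothetical condensation for readers skimming the switch blocks; the enum, function name, and parameter names are invented, and the cols_per_block/nwarps choices are collapsed into comments.

#include <cstdint>

// Hypothetical summary of the kernel selection above; first match wins.
enum class fattn_choice {
    vec_f16_no_mma, vec_f16, vec_f32, tile_f16, tile_f32, wmma_acc_f32, wmma_acc_f16
};

static fattn_choice choose_fattn_kernel(bool is_amd, bool fast_fp16, bool fp16_mma,
                                        bool high_precision, int64_t ne0 /* head dim */,
                                        int64_t ne1 /* Q columns */, int warp_size = 32) {
    if (is_amd) {      // tile kernels perform poorly on AMD
        return high_precision ? fattn_choice::vec_f32 : fattn_choice::vec_f16_no_mma;
    }
    if (!fast_fp16) {  // no fast fp16 arithmetic: use the f32 kernels
        return ne1 <= 8 ? fattn_choice::vec_f32 : fattn_choice::tile_f32;
    }
    if (!fp16_mma) {   // fp16 arithmetic but no tensor cores
        return ne1 <= 8 ? fattn_choice::vec_f16_no_mma : fattn_choice::tile_f16;
    }
    if (high_precision) {  // precision != GGML_PREC_DEFAULT: accumulate KQ in float
        if (ne1 == 1 && (ne0 == 64 || ne0 == 128)) {
            return fattn_choice::vec_f32;
        }
        return fattn_choice::wmma_acc_f32;  // cols_per_block 16 or 32 depending on ne1/ne0
    }
    if (ne1 == 1 && ne0 % (2*warp_size) == 0) {
        return fattn_choice::vec_f16;       // single-column fast path
    }
    return fattn_choice::wmma_acc_f16;      // cols_per_block 8/16/32 depending on ne1
}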
llm/llama.cpp/ggml-cuda/fattn.cuh deleted 100644 → 0 View file @ 97b02a89
#include "common.cuh"
void
ggml_cuda_flash_attn_ext
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
);
llm/llama.cpp/ggml-cuda/getrows.cu deleted 100644 → 0 View file @ 97b02a89
#include "getrows.cuh"
#include "dequantize.cuh"
template
<
int
qk
,
int
qr
,
dequantize_kernel_t
dequantize_kernel
,
typename
dst_t
>
static
__global__
__launch_bounds__
(
1024
)
void
k_get_rows
(
const
void
*
src0
,
const
int32_t
*
src1
,
dst_t
*
dst
,
int64_t
ne00
,
/*int64_t ne01, int64_t ne02, int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/
int64_t
ne12
,
/*int64_t ne13,*/
/*size_t s0,*/
size_t
s1
,
size_t
s2
,
size_t
s3
,
/*size_t nb00,*/
size_t
nb01
,
size_t
nb02
,
size_t
nb03
,
size_t
s10
,
size_t
s11
,
size_t
s12
/*, size_t s13*/
)
{
const
int
i00
=
(
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
)
*
2
;
const
int
i10
=
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
;
const
int
i11
=
(
blockIdx
.
z
*
blockDim
.
z
+
threadIdx
.
z
)
/
ne12
;
const
int
i12
=
(
blockIdx
.
z
*
blockDim
.
z
+
threadIdx
.
z
)
%
ne12
;
if
(
i00
>=
ne00
)
{
return
;
}
const
int
i01
=
src1
[
i10
*
s10
+
i11
*
s11
+
i12
*
s12
];
dst_t
*
dst_row
=
dst
+
i10
*
s1
+
i11
*
s2
+
i12
*
s3
;
const
void
*
src0_row
=
(
const
char
*
)
src0
+
i01
*
nb01
+
i11
*
nb02
+
i12
*
nb03
;
const
int
ib
=
i00
/
qk
;
// block index
const
int
iqs
=
(
i00
%
qk
)
/
qr
;
// quant index
const
int
iybs
=
i00
-
i00
%
qk
;
// dst block start index
const
int
y_offset
=
qr
==
1
?
1
:
qk
/
2
;
// dequantize
dfloat2
v
;
dequantize_kernel
(
src0_row
,
ib
,
iqs
,
v
);
dst_row
[
iybs
+
iqs
+
0
]
=
v
.
x
;
dst_row
[
iybs
+
iqs
+
y_offset
]
=
v
.
y
;
}
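In the quantized path each thread handles one even i00 and writes the two values produced by dequantize_kernel half a block apart. A worked example of that index math, assuming the Q4_0-style layout constants qk = 32 and qr = 2 (values computed by hand, purely illustrative):

#include <cassert>

int main() {
    const int qk = 32, qr = 2;
    const int i00 = 70;                       // even, as guaranteed by the *2 in the kernel

    const int ib       = i00/qk;              // 2  : which quant block the element lives in
    const int iqs      = (i00 % qk)/qr;       // 3  : quant index inside that block
    const int iybs     = i00 - i00 % qk;      // 64 : first dst index of this block
    const int y_offset = qr == 1 ? 1 : qk/2;  // 16 : second value lands half a block later

    assert(ib == 2 && iqs == 3 && iybs == 64 && y_offset == 16);
    // dequantize_kernel(..., ib, iqs, v) then writes v.x to dst_row[67] and v.y to dst_row[83]:
    assert(iybs + iqs + 0        == 67);
    assert(iybs + iqs + y_offset == 83);
    return 0;
}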
template<typename src0_t, typename dst_t>
static __global__ __launch_bounds__(1024) void k_get_rows_float(
            const src0_t * src0, const int32_t * src1, dst_t * dst,
            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
            size_t s10, size_t s11, size_t s12/*, size_t s13*/) {

    const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
    const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;

    if (i00 >= ne00) {
        return;
    }

    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];

    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
    const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);

    dst_row[i00] = src0_row[i00];
}
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                          const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {

    GGML_TENSOR_BINARY_OP_LOCALS

    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
    const dim3 block_nums(block_num_x, ne10, ne11*ne12);

    // strides in elements
    //const size_t s0 = nb0 / ggml_element_size(dst);
    const size_t s1 = nb1 / ggml_element_size(dst);
    const size_t s2 = nb2 / ggml_element_size(dst);
    const size_t s3 = nb3 / ggml_element_size(dst);

    const size_t s10 = nb10 / ggml_element_size(src1);
    const size_t s11 = nb11 / ggml_element_size(src1);
    const size_t s12 = nb12 / ggml_element_size(src1);
    //const size_t s13 = nb13 / ggml_element_size(src1);

    GGML_ASSERT(ne00 % 2 == 0);

    k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
            src0_dd, src1_dd, dst_dd,
            ne00, /*ne01, ne02, ne03,*/
            /*ne10, ne11,*/ ne12, /*ne13,*/
            /* s0,*/ s1, s2, s3,
            /* nb00,*/ nb01, nb02, nb03,
            s10, s11, s12/*, s13*/);

    GGML_UNUSED(dst);
}
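Because each k_get_rows thread dequantizes a pair of values, the x dimension of the grid is sized for 2*CUDA_GET_ROWS_BLOCK_SIZE elements per block, while the float variant below uses the plain one-element-per-thread formula. A quick sanity check of that arithmetic, with an assumed row length of 4096:

#include <cassert>

int main() {
    const int block_size = 256;   // CUDA_GET_ROWS_BLOCK_SIZE from getrows.cuh
    const int ne00       = 4096;  // assumed row length (must be even on the quantized path)

    // quantized path: two elements per thread -> 512 elements per block
    const int block_num_x_quant = (ne00 + 2*block_size - 1) / (2*block_size);
    // float path: one element per thread
    const int block_num_x_float = (ne00 + block_size - 1) / block_size;

    assert(block_num_x_quant == 8);
    assert(block_num_x_float == 16);
    return 0;
}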
template<typename src0_t>
static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                                const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {

    GGML_TENSOR_BINARY_OP_LOCALS

    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
    const dim3 block_nums(block_num_x, ne10, ne11*ne12);

    // strides in elements
    //const size_t s0 = nb0 / ggml_element_size(dst);
    const size_t s1 = nb1 / ggml_element_size(dst);
    const size_t s2 = nb2 / ggml_element_size(dst);
    const size_t s3 = nb3 / ggml_element_size(dst);

    const size_t s10 = nb10 / ggml_element_size(src1);
    const size_t s11 = nb11 / ggml_element_size(src1);
    const size_t s12 = nb12 / ggml_element_size(src1);
    //const size_t s13 = nb13 / ggml_element_size(src1);

    k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
            src0_dd, src1_dd, dst_dd,
            ne00, /*ne01, ne02, ne03,*/
            /*ne10, ne11,*/ ne12, /*ne13,*/
            /* s0,*/ s1, s2, s3,
            /* nb00,*/ nb01, nb02, nb03,
            s10, s11, s12/*, s13*/);

    GGML_UNUSED(dst);
}
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const float * src0_d = (const float *) src0->data;
    const float * src1_d = (const float *) src1->data;
    float * dst_d = (float *) dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src1->type == GGML_TYPE_I32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
    GGML_ASSERT(dst->nb[0]  == ggml_type_size(dst->type));

    const int32_t * src1_i32 = (const int32_t *) src1_d;

    switch (src0->type) {
        case GGML_TYPE_F16:
            get_rows_cuda_float(src0, src1, dst, (const half *) src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_F32:
            get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q4_0:
            get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q4_1:
            get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q5_0:
            get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q5_1:
            get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        case GGML_TYPE_Q8_0:
            get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
            break;
        default:
            // TODO: k-quants
            fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
            GGML_ASSERT(false);
            break;
    }
}
llm/llama.cpp/ggml-cuda/getrows.cuh deleted 100644 → 0 View file @ 97b02a89
#include "common.cuh"
#define CUDA_GET_ROWS_BLOCK_SIZE 256
void
ggml_cuda_op_get_rows
(
ggml_backend_cuda_context
&
ctx
,
ggml_tensor
*
dst
);