OpenDAS / ollama · Commit 544b6739 (unverified)

ggml update to b6840 (#12791)

Authored Nov 06, 2025 by Daniel Hiltgen; committed by GitHub on Nov 06, 2025.
Parent: c4ba257c
Changes: 103 files in this commit. Showing 20 changed files with 1529 additions and 266 deletions.
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h                             +1    -0
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m                             +22   -11
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal                          +254  -65
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h                               +14   -0
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp                              +75   -1
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h                                +1    -0
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal                                +240  -65
ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt                                 +9    -0
ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp                                +509  -93
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl          +14   -0
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp                 +1    -1
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl            +19   -1
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp             +1    -1
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp             +5    -1
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp  +1    -1
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp                     +30   -20
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp                   +44   -0
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp                   +140  -0
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp                   +139  -0
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp          +10   -6
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h
@@ -130,6 +130,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col           (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad              (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d   (ggml_metal_library_t lib, const struct ggml_tensor * op);
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m
@@ -7,6 +7,8 @@
 #include <Metal/Metal.h>
+#include <stdatomic.h>
+
 #ifndef TARGET_OS_VISION
 #define TARGET_OS_VISION 0
 #endif
@@ -22,6 +24,9 @@
 // overload of MTLGPUFamilyMetal3 (not available in some environments)
 static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
 
+// virtual address for GPU memory allocations
+static atomic_uintptr_t g_addr_device = 0x000000400ULL;
+
 #if !GGML_METAL_EMBED_LIBRARY
 // Here to assist with NSBundle Path Hack
 @interface GGMLMetalClass : NSObject
@@ -648,6 +653,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_SCALE:
         case GGML_OP_CONV_TRANSPOSE_1D:
             return true;
+        case GGML_OP_CONV_TRANSPOSE_2D:
+            return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) &&
+                (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
+                op->src[1]->type == GGML_TYPE_F32 &&
+                op->type == GGML_TYPE_F32;
         case GGML_OP_CLAMP:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SQR:
@@ -657,6 +667,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_LOG:
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SUM:
+            return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
@@ -693,7 +704,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
             // for new head sizes, add checks here
-            if (op->src[0]->ne[0] != 40 &&
+            if (op->src[0]->ne[0] != 32 &&
+                op->src[0]->ne[0] != 40 &&
                 op->src[0]->ne[0] != 64 &&
                 op->src[0]->ne[0] != 80 &&
                 op->src[0]->ne[0] != 96 &&
@@ -826,7 +838,7 @@ struct ggml_metal_buffer_wrapper {
 };
 
 struct ggml_metal_buffer {
-    void * all_data;
+    void * all_data; // TODO: https://github.com/ggml-org/llama.cpp/pull/15985
     size_t all_size;
 
     // if false, the Metal buffer data is allocated in private GPU memory and is not shared with the host
@@ -964,14 +976,15 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
     if (shared) {
         res->all_data = ggml_metal_host_malloc(size_aligned);
         res->is_shared = true;
+        res->owned = true;
     } else {
-        // dummy, non-NULL value - we'll populate this after creating the Metal buffer below
-        res->all_data = (void *) 0x000000400ULL;
+        // use virtual address from g_addr_device counter
+        res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed);
         res->is_shared = false;
     }
 
     res->all_size = size_aligned;
-    res->owned = true;
 
     res->device = ggml_metal_device_get_obj  (dev);
     res->queue  = ggml_metal_device_get_queue(dev);
@@ -982,15 +995,13 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size,
     res->buffers[0].metal = nil;
 
     if (size_aligned > 0) {
         if (props_dev->use_shared_buffers && shared) {
             res->buffers[0].metal = [res->device newBufferWithBytesNoCopy:res->all_data
                                                  length:size_aligned
                                                  options:MTLResourceStorageModeShared
                                                  deallocator:nil];
         } else {
             res->buffers[0].metal = [res->device newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate];
-
-            res->all_data = (void *) (res->buffers[0].metal.gpuAddress);
         }
     }
@@ -1138,7 +1149,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf) {
 void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
     if (buf->is_shared) {
         memset((char *) tensor->data + offset, value, size);
         return;
     }
@@ -1167,7 +1178,7 @@ void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor
 void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
     if (buf->is_shared) {
         memcpy((char *) tensor->data + offset, data, size);
         return;
     }
@@ -1222,7 +1233,7 @@ void ggml_metal_buffer_set_tensor(ggml_metal_buffer_t buf, struct ggml_tensor *
 void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     if (buf->is_shared) {
         memcpy(data, (const char *) tensor->data + offset, size);
         return;
     }
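The non-shared branch above now draws res->all_data from the g_addr_device counter, so every private buffer gets a distinct, non-overlapping base address instead of the old fixed dummy pointer. A minimal standalone C sketch of the same pattern; reserve_virtual_range is an illustrative name, not part of ggml:

#include <stdatomic.h>
#include <stdint.h>
#include <stdio.h>

// Monotonic counter handing out non-overlapping "virtual" base addresses for
// buffers whose storage is private GPU memory (mirrors g_addr_device above).
static atomic_uintptr_t addr_counter = 0x000000400ULL;

// Reserve a range of size_aligned bytes and return its base address.
static void * reserve_virtual_range(size_t size_aligned) {
    return (void *) atomic_fetch_add_explicit(&addr_counter, size_aligned, memory_order_relaxed);
}

int main(void) {
    void * a = reserve_virtual_range(4096);
    void * b = reserve_virtual_range(8192);
    printf("a = %p, b = %p\n", a, b); // b begins exactly 4096 bytes after a
    return 0;
}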
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
@@ -2136,6 +2136,7 @@ typedef struct {
     int32_t sect_1;
     int32_t sect_2;
     int32_t sect_3;
+    bool    src2;
 } ggml_metal_kargs_rope;
 
 typedef struct {
@@ -2398,6 +2399,19 @@ typedef struct {
     uint64_t nb1;
 } ggml_metal_kargs_conv_transpose_1d;
 
+typedef struct {
+    int32_t  IC;
+    int32_t  IH;
+    int32_t  IW;
+    int32_t  KH;
+    int32_t  KW;
+    int32_t  OC;
+    int32_t  s0;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_conv_transpose_2d;
+
 typedef struct {
     uint64_t ofs0;
     uint64_t ofs1;
@@ -4392,18 +4406,48 @@ kernel void kernel_op_sum_f32(
         constant ggml_metal_kargs_sum & args,
         device const float * src0,
         device       float * dst,
-        ushort tiitg[[thread_index_in_threadgroup]]) {
-    if (tiitg != 0) {
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    if (args.np == 0) {
         return;
     }
 
-    float acc = 0.0f;
-    for (ulong i = 0; i < args.np; ++i) {
-        acc += src0[i];
+    const uint nsg = (ntg.x + 31) / 32;
+
+    float sumf = 0;
+
+    for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
+        sumf += src0[i0];
     }
-    dst[0] = acc;
+
+    sumf = simd_sum(sumf);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float total = 0;
+
+    if (sgitg == 0) {
+        float v = 0;
+        if (tpitg.x < nsg) {
+            v = shmem_f32[tpitg.x];
+        }
+
+        total = simd_sum(v);
+
+        if (tpitg.x == 0) {
+            dst[0] = total;
+        }
+    }
 }
 
 template <bool norm>
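The rewritten kernel_op_sum_f32 above reduces in three stages: each thread accumulates a strided partial sum, simd_sum collapses each simdgroup of 32 threads, and the per-simdgroup partials left in threadgroup memory are reduced once more by simdgroup 0. A plain-C model of the same reduction, assuming at most 1024 threads per threadgroup; sum_two_level is illustrative, not a ggml function:

#include <stdio.h>

static float sum_two_level(const float * src0, long np, int ntg) {
    const int nsg = (ntg + 31) / 32; // number of simdgroups
    float shmem[32] = {0};           // per-simdgroup partials (ntg <= 1024 assumed)

    for (int sg = 0; sg < nsg; ++sg) {
        float sumf = 0.0f;           // stands in for the simd_sum over one simdgroup
        for (int lane = 0; lane < 32; ++lane) {
            const int tid = sg*32 + lane;
            if (tid >= ntg) break;
            for (long i0 = tid; i0 < np; i0 += ntg) {
                sumf += src0[i0];    // strided per-thread partial sum
            }
        }
        shmem[sg] = sumf;            // stands in for shmem_f32[sgitg] = simd_sum(sumf)
    }

    float total = 0.0f;
    for (int sg = 0; sg < nsg; ++sg) {
        total += shmem[sg];          // final reduction over simdgroup partials
    }
    return total;
}

int main(void) {
    float x[100];
    for (int i = 0; i < 100; ++i) x[i] = 1.0f;
    printf("%f\n", sum_two_level(x, 100, 64)); // prints 100.000000
    return 0;
}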
@@ -6413,7 +6457,7 @@ kernel void kernel_rope_norm(
             const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
 
-            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
 
             rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@@ -6466,7 +6510,7 @@ kernel void kernel_rope_neox(
             const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
 
-            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
 
             rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@@ -6533,7 +6577,7 @@ kernel void kernel_rope_multi(
             const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
 
-            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
 
             rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
@@ -6600,7 +6644,7 @@ kernel void kernel_rope_vision(
             const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
             // end of mrope
 
-            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+            const float freq_factor = args.src2 ? ((device const float *) src2)[ic] : 1.0f;
 
             rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
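All four rope kernels now key the optional frequency-factor lookup on the new args.src2 flag instead of comparing the src2 and src0 device pointers, presumably because pointer identity is no longer a reliable "no freq-factors bound" signal once buffer base addresses can be synthetic virtual addresses (see the ggml-metal-device.m change above). A hedged C sketch of the host-side intent, using an abridged, hypothetical view of the kernel-argument struct:

#include <stdbool.h>
#include <stdint.h>
#include <stddef.h>

// Abridged, hypothetical slice of ggml_metal_kargs_rope; only the fields
// relevant to the change are shown, and the layout is illustrative.
typedef struct {
    int32_t sect_0, sect_1, sect_2, sect_3;
    bool    src2; // new: "a frequency-factors tensor is bound"
} rope_kargs_sketch;

// Hypothetical host-side fill: the flag simply records whether the optional
// src[2] tensor exists, so the kernel no longer infers it from pointer identity.
static void fill_rope_kargs(rope_kargs_sketch * args, const void * src2_tensor) {
    args->src2 = (src2_tensor != NULL);
}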
@@ -6810,6 +6854,97 @@ kernel void kernel_conv_transpose_1d<half>(
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3   tgpg[[threadgroups_per_grid]]);
 
+typedef void (conv_transpose_2d_t)(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const float * src0,
+        device const float * src1,
+        device        char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3   tgpg[[threadgroups_per_grid]]);
+
+template <typename T>
+kernel void kernel_conv_transpose_2d(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const T     * src0,
+        device const float * src1,
+        device        char * dst,
+        threadgroup float  * shared_sum [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]]) {
+
+    const int64_t out_x = tgpig[0];
+    const int64_t out_y = tgpig[1];
+    const int64_t out_c = tgpig[2];
+
+    const int64_t kw = tpitg[0];
+    const int64_t kh = tpitg[1];
+
+    float v = 0.0f;
+
+    for (int64_t in_c = 0; in_c < args.IC; in_c++) {
+        int64_t in_y = out_y - kh;
+        if (in_y < 0 || in_y % args.s0) continue;
+        in_y /= args.s0;
+        if (in_y >= args.IH) continue;
+
+        int64_t in_x = out_x - kw;
+        if (in_x < 0 || in_x % args.s0) continue;
+        in_x /= args.s0;
+        if (in_x >= args.IW) continue;
+
+        const int64_t input_idx  = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
+        const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
+
+        v += (float)src0[kernel_idx] * src1[input_idx];
+    }
+
+    const uint tid = tpitg.y * ntg.x + tpitg.x;
+    shared_sum[tid] = v;
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tid == 0) {
+        float total = 0.0f;
+        const uint num_threads = ntg.x * ntg.y;
+        for (uint i = 0; i < num_threads; i++) {
+            total += shared_sum[i];
+        }
+
+        device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
+        dst_ptr[0] = total;
+    }
+}
+
+template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
+kernel void kernel_conv_transpose_2d<float>(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const float * src0,
+        device const float * src1,
+        device        char * dst,
+        threadgroup float  * shared_sum [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]]);
+
+template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
+kernel void kernel_conv_transpose_2d<half>(
+        constant ggml_metal_kargs_conv_transpose_2d & args,
+        device const half  * src0,
+        device const float * src1,
+        device        char * dst,
+        threadgroup float  * shared_sum [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        uint3  tpitg[[thread_position_in_threadgroup]],
+        uint3    ntg[[threads_per_threadgroup]]);
+
 kernel void kernel_upscale_f32(
         constant ggml_metal_kargs_upscale & args,
         device const char * src0,
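The new kernel launches one threadgroup per output element (out_x, out_y, out_c) with one thread per kernel tap (kw, kh); each thread accumulates over input channels and the partial products are combined through shared_sum. A CPU reference sketch of the same computation, matching the kernel's indexing (stride s0 in both dimensions, no padding, weights laid out [IC][OC][KH][KW], input [IC][IH][IW]); output dimensions follow the usual transposed-convolution relation OW = (IW - 1)*s0 + KW and OH = (IH - 1)*s0 + KH. conv_transpose_2d_ref is an illustrative name, not a ggml function:

// dst must hold OC*OH*OW floats.
static void conv_transpose_2d_ref(
        const float * kern, const float * src, float * dst,
        int IC, int IH, int IW, int KH, int KW, int OC, int s0) {
    const int OH = (IH - 1)*s0 + KH;
    const int OW = (IW - 1)*s0 + KW;

    for (int oc = 0; oc < OC; ++oc) {
        for (int oy = 0; oy < OH; ++oy) {
            for (int ox = 0; ox < OW; ++ox) {
                float v = 0.0f;
                for (int kh = 0; kh < KH; ++kh) {
                    for (int kw = 0; kw < KW; ++kw) {
                        const int iy = oy - kh;
                        const int ix = ox - kw;
                        // skip taps that do not land on a stride-s0 input position
                        if (iy < 0 || iy % s0 || ix < 0 || ix % s0) continue;
                        if (iy/s0 >= IH || ix/s0 >= IW) continue;
                        for (int ic = 0; ic < IC; ++ic) {
                            const long input_idx  = (long)(IW*IH)*ic + (long)IW*(iy/s0) + ix/s0;
                            const long kernel_idx = (long)(KH*KW*OC)*ic + (long)(KH*KW)*oc + (long)KW*kh + kw;
                            v += kern[kernel_idx] * src[input_idx];
                        }
                    }
                }
                dst[(long)(OW*OH)*oc + (long)OW*oy + ox] = v;
            }
        }
    }
}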
@@ -7938,8 +8073,30 @@ kernel void kernel_flash_attn_ext(
     half,  half4,   simdgroup_half8x8
     //float, float4, simdgroup_float8x8
 
+#define FA_TYPES_F32 \
+    half,  half4,    simdgroup_half8x8,  \
+    float, float4x4, simdgroup_float8x8, \
+    float, float4x4, simdgroup_float8x8, \
+    float,           simdgroup_float8x8, \
+    float, float2,   simdgroup_float8x8, \
+    float, float4,   simdgroup_float8x8
+    //half, half4, simdgroup_half8x8
+
 typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
 
+template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
+template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
+template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
+template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
+template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
+template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
+template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
+template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
+template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
+template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
+template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
+template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
 template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80"  )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
@@ -7952,6 +8109,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
 
 #if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
 template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
@@ -7964,6 +8122,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
 #endif
 
+template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
@@ -7975,6 +8134,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
 template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
 
+template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
@@ -7986,6 +8146,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
 template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
 
+template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
@@ -7997,6 +8158,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
 template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
 
+template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
@@ -8008,6 +8170,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
 template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
 template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
 
+template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
 template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
@@ -8543,77 +8706,103 @@ kernel void kernel_flash_attn_ext_vec(
     float, float4, \
            float4
 
+#define FA_TYPES_F32 \
+           half4,  \
+           float4, \
+           float4, \
+           float,  \
+    float, float4, \
+           float4
+
 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
 
+template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
+#endif
+template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
+template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
+
+template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
 #if defined(GGML_METAL_HAS_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
 #endif
 template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
 
+template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
 template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
 #if defined(GGML_METAL_HAS_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
 #endif
 template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
 template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
 
+template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
 template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
 #if defined(GGML_METAL_HAS_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
 #endif
 template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
 template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
 
+template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
 #if defined(GGML_METAL_HAS_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
 #endif
 template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
 
+template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
 #if defined(GGML_METAL_HAS_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
 #endif
 template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
 
+template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
 template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
 #if defined(GGML_METAL_HAS_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
 #endif
 template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
 template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
 template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
 
+template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
 template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
#undef FA_TYPES
...
...
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h
View file @ 544b6739
...
@@ -251,6 +251,7 @@ typedef struct {
    int32_t  sect_1;
    int32_t  sect_2;
    int32_t  sect_3;
    bool     src2;
} ggml_metal_kargs_rope;

typedef struct {
...
@@ -513,6 +514,19 @@ typedef struct {
    uint64_t nb1;
} ggml_metal_kargs_conv_transpose_1d;

typedef struct {
    int32_t  IC;
    int32_t  IH;
    int32_t  IW;
    int32_t  KH;
    int32_t  KW;
    int32_t  OC;
    int32_t  s0;
    uint64_t nb0;
    uint64_t nb1;
    uint64_t nb2;
} ggml_metal_kargs_conv_transpose_2d;

typedef struct {
    uint64_t ofs0;
    uint64_t ofs1;
...
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp
View file @ 544b6739
...
@@ -368,6 +368,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
            {
                n_fuse = ggml_metal_op_conv_transpose_1d(ctx, idx);
            } break;
        case GGML_OP_CONV_TRANSPOSE_2D:
            {
                n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx);
            } break;
        case GGML_OP_UPSCALE:
            {
                n_fuse = ggml_metal_op_upscale(ctx, idx);
...
@@ -866,12 +870,25 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);

    int nth = 32; // SIMD width

    while (nth < (int) n && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
        nth *= 2;
    }

    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
    nth = std::min(nth, (int) n);

    const int nsg = (nth + 31)/32;

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);

    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
    ggml_metal_encoder_set_threadgroup_memory_size(enc, nsg*sizeof(float), 0);
    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1);

    return 1;
}
...
@@ -2969,6 +2986,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
        /* sect_1 =*/ sect_1,
        /* sect_2 =*/ sect_2,
        /* sect_3 =*/ sect_3,
        /* src2   =*/ op->src[2] != nullptr,
    };

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
...
@@ -3104,6 +3122,62 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
    return 1;
}

int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    GGML_TENSOR_LOCALS(int32_t,  ne0, op->src[0], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
    GGML_TENSOR_LOCALS(int32_t,  ne1, op->src[1], ne);
    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
    GGML_TENSOR_LOCALS(int32_t,  ne,  op, ne);
    GGML_TENSOR_LOCALS(uint32_t, nb,  op, nb);

    const int32_t s0 = ((const int32_t *)(op->op_params))[0];

    const int32_t IC = op->src[1]->ne[2];
    const int32_t IH = op->src[1]->ne[1];
    const int32_t IW = op->src[1]->ne[0];

    const int32_t KH = op->src[0]->ne[1];
    const int32_t KW = op->src[0]->ne[0];

    const int32_t OW = op->ne[0];
    const int32_t OH = op->ne[1];
    const int32_t OC = op->ne[2];

    ggml_metal_kargs_conv_transpose_2d args = {
        /*.IC  =*/ IC,
        /*.IH  =*/ IH,
        /*.IW  =*/ IW,
        /*.KH  =*/ KH,
        /*.KW  =*/ KW,
        /*.OC  =*/ OC,
        /*.s0  =*/ s0,
        /*.nb0 =*/ nb0,
        /*.nb1 =*/ nb1,
        /*.nb2 =*/ nb2,
    };

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 3);

    // Metal requires buffer size to be multiple of 16 bytes
    const size_t smem = GGML_PAD(KW*KH*sizeof(float), 16);

    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
    ggml_metal_encoder_dispatch_threadgroups(enc, OW, OH, OC, KW, KH, 1);

    return 1;
}

int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);
...
...
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h
View file @ 544b6739
...
@@ -71,6 +71,7 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rope             (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_im2col           (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_upscale          (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad              (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pad_reflect_1d   (ggml_metal_op_t ctx, int idx);
...
...
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
View file @ 544b6739
...
@@ -1727,18 +1727,48 @@ kernel void kernel_op_sum_f32(
        constant ggml_metal_kargs_sum & args,
        device const float * src0,
        device       float * dst,
        ushort tiitg[[thread_index_in_threadgroup]]) {
        threadgroup float  * shmem_f32 [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    if (tiitg != 0) {
    if (args.np == 0) {
        return;
    }

    float acc = 0.0f;
    const uint nsg = (ntg.x + 31) / 32;

    for (ulong i = 0; i < args.np; ++i) {
        acc += src0[i];
    float sumf = 0;

    for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
        sumf += src0[i0];
    }

    dst[0] = acc;
    sumf = simd_sum(sumf);

    if (tiisg == 0) {
        shmem_f32[sgitg] = sumf;
    }

    threadgroup_barrier(mem_flags::mem_threadgroup);

    float total = 0;

    if (sgitg == 0) {
        float v = 0;

        if (tpitg.x < nsg) {
            v = shmem_f32[tpitg.x];
        }

        total = simd_sum(v);

        if (tpitg.x == 0) {
            dst[0] = total;
        }
    }
}
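The reworked kernel_op_sum_f32 above replaces the old single-thread loop with a strided per-thread accumulation, a simd_sum() within each 32-lane simdgroup, and a final fold of the nsg partials by simdgroup 0. A minimal CPU-side sketch of the same two-stage pattern, for illustration only (the concrete sizes, the std::vector scaffolding, and the main() wrapper are assumptions, not part of this commit):

// Illustrative CPU mirror of the updated reduction (assumed sizes: np = 1000, ntg = 256).
#include <cstdio>
#include <vector>

int main() {
    const int SIMD_WIDTH = 32;        // one Metal simdgroup
    const int ntg        = 256;       // threads per threadgroup, as sized by ggml_metal_op_sum
    const int nsg        = (ntg + SIMD_WIDTH - 1) / SIMD_WIDTH;

    std::vector<float> src0(1000, 1.0f);      // args.np elements
    std::vector<float> per_thread(ntg, 0.0f);
    std::vector<float> shmem_f32(nsg, 0.0f);

    // stage 0: each "thread" accumulates a strided slice of src0
    for (int t = 0; t < ntg; ++t)
        for (size_t i0 = t; i0 < src0.size(); i0 += ntg)
            per_thread[t] += src0[i0];

    // stage 1: collapse each simdgroup into one partial (simd_sum in the kernel)
    for (int sg = 0; sg < nsg; ++sg)
        for (int l = 0; l < SIMD_WIDTH; ++l)
            shmem_f32[sg] += per_thread[sg*SIMD_WIDTH + l];

    // stage 2: simdgroup 0 folds the nsg partials and writes dst[0]
    float dst0 = 0.0f;
    for (int sg = 0; sg < nsg; ++sg)
        dst0 += shmem_f32[sg];

    std::printf("%g\n", dst0);        // prints 1000
    return 0;
}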
template <bool norm>
...
@@ -3748,7 +3778,7 @@ kernel void kernel_rope_norm(
            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
            const float freq_factor = args.src2    ? ((device const float *) src2)[ic] : 1.0f;
            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
...
@@ -3801,7 +3831,7 @@ kernel void kernel_rope_neox(
            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
            const float freq_factor = args.src2    ? ((device const float *) src2)[ic] : 1.0f;
            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
...
@@ -3868,7 +3898,7 @@ kernel void kernel_rope_multi(
            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
            const float freq_factor = args.src2    ? ((device const float *) src2)[ic] : 1.0f;
            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
...
@@ -3935,7 +3965,7 @@ kernel void kernel_rope_vision(
            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
            // end of mrope
            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
            const float freq_factor = args.src2    ? ((device const float *) src2)[ic] : 1.0f;
            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
...
@@ -4145,6 +4175,97 @@ kernel void kernel_conv_transpose_1d<half>(
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tgpg[[threadgroups_per_grid]]);
typedef void (conv_transpose_2d_t)(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tgpg[[threadgroups_per_grid]]);
template <typename T>
kernel void kernel_conv_transpose_2d(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const T * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t out_x = tgpig[0];
const int64_t out_y = tgpig[1];
const int64_t out_c = tgpig[2];
const int64_t kw = tpitg[0];
const int64_t kh = tpitg[1];
float v = 0.0f;
for (int64_t in_c = 0; in_c < args.IC; in_c++) {
int64_t in_y = out_y - kh;
if (in_y < 0 || in_y % args.s0) continue;
in_y /= args.s0;
if (in_y >= args.IH) continue;
int64_t in_x = out_x - kw;
if (in_x < 0 || in_x % args.s0) continue;
in_x /= args.s0;
if (in_x >= args.IW) continue;
const int64_t input_idx = (args.IW * args.IH) * in_c + (args.IW) * in_y + in_x;
const int64_t kernel_idx = (args.KH * args.KW * args.OC) * in_c + (args.KH * args.KW) * out_c + (args.KW) * kh + kw;
v += (float)src0[kernel_idx] * src1[input_idx];
}
const uint tid = tpitg.y * ntg.x + tpitg.x;
shared_sum[tid] = v;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float total = 0.0f;
const uint num_threads = ntg.x * ntg.y;
for (uint i = 0; i < num_threads; i++) {
total += shared_sum[i];
}
device float * dst_ptr = (device float *) (dst + out_x*args.nb0 + out_y * args.nb1 + out_c*args.nb2);
dst_ptr[0] = total;
}
}
template [[host_name("kernel_conv_transpose_2d_f32_f32")]]
kernel void kernel_conv_transpose_2d<float>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const float * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
template [[host_name("kernel_conv_transpose_2d_f16_f32")]]
kernel void kernel_conv_transpose_2d<half>(
constant ggml_metal_kargs_conv_transpose_2d & args,
device const half * src0,
device const float * src1,
device char * dst,
threadgroup float * shared_sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]);
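The new kernel_conv_transpose_2d maps one threadgroup to each output element (an OW x OH x OC grid, dispatched from ggml_metal_op_conv_transpose_2d above) and lets the KW x KH threads of that group accumulate over the input channels before the shared-memory fold writes the result. A plain C++ reference of the same index math, handy for checking the Metal output on small shapes (the contiguous array layouts and the main() scaffolding are assumptions, not part of the commit):

// Reference transposed 2D convolution mirroring the kernel's indexing.
// src0 (weights): [KW, KH, OC, IC] contiguous, src1 (input): [IW, IH, IC], dst: [OW, OH, OC].
#include <cstdio>
#include <vector>

static void conv_transpose_2d_ref(const float * src0, const float * src1, float * dst,
                                  int IW, int IH, int IC, int KW, int KH, int OC,
                                  int OW, int OH, int s0) {
    for (int out_c = 0; out_c < OC; ++out_c)
    for (int out_y = 0; out_y < OH; ++out_y)
    for (int out_x = 0; out_x < OW; ++out_x) {
        float total = 0.0f;
        for (int kh = 0; kh < KH; ++kh)
        for (int kw = 0; kw < KW; ++kw) {
            const int ys = out_y - kh;
            const int xs = out_x - kw;
            if (ys < 0 || ys % s0 || xs < 0 || xs % s0) continue; // same skip logic as the kernel
            const int in_y = ys / s0;
            const int in_x = xs / s0;
            if (in_y >= IH || in_x >= IW) continue;
            for (int in_c = 0; in_c < IC; ++in_c) {
                const int input_idx  = (IW*IH)*in_c + IW*in_y + in_x;
                const int kernel_idx = (KH*KW*OC)*in_c + (KH*KW)*out_c + KW*kh + kw;
                total += src0[kernel_idx]*src1[input_idx];
            }
        }
        dst[(OW*OH)*out_c + OW*out_y + out_x] = total;
    }
}

int main() {
    // assumed toy shapes: 2x2 input, 3x3 kernel, 1 channel in/out, stride 1 -> 4x4 output
    const int IW = 2, IH = 2, IC = 1, KW = 3, KH = 3, OC = 1, s0 = 1;
    const int OW = (IW - 1)*s0 + KW, OH = (IH - 1)*s0 + KH;
    std::vector<float> src0(KW*KH*OC*IC, 1.0f), src1(IW*IH*IC, 1.0f), dst(OW*OH*OC, 0.0f);
    conv_transpose_2d_ref(src0.data(), src1.data(), dst.data(), IW, IH, IC, KW, KH, OC, OW, OH, s0);
    for (int y = 0; y < OH; ++y) {
        for (int x = 0; x < OW; ++x) std::printf("%g ", dst[OW*y + x]);
        std::printf("\n");
    }
    return 0;
}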
kernel void kernel_upscale_f32(
        constant ggml_metal_kargs_upscale & args,
        device const char * src0,
...
@@ -5273,8 +5394,30 @@ kernel void kernel_flash_attn_ext(
    half, half4, simdgroup_half8x8
    //float, float4, simdgroup_float8x8
#define FA_TYPES_F32 \
half, half4, simdgroup_half8x8, \
float, float4x4, simdgroup_float8x8, \
float, float4x4, simdgroup_float8x8, \
float, simdgroup_float8x8, \
float, float2, simdgroup_float8x8, \
float, float4, simdgroup_float8x8
//half, half4, simdgroup_half8x8
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>) flash_attn_ext_t;
template [[host_name("kernel_flash_attn_ext_f32_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f32_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f32_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f32_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f32_dk96_dv96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 96, 96>;
template [[host_name("kernel_flash_attn_ext_f32_dk112_dv112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 112, 112>;
template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 128, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>;
template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>;
template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>;
template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>;
template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_f16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80, 80>;
...
@@ -5287,6 +5430,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
...
@@ -5287,6 +5430,7 @@ template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>;
#if defined(GGML_METAL_HAS_BF16)
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_bf16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 32, 32>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 40, 40>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64, 64>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
template [[host_name("kernel_flash_attn_ext_bf16_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80, 80>;
...
@@ -5299,6 +5443,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
...
@@ -5299,6 +5443,7 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>;
#endif
#endif
template [[host_name("kernel_flash_attn_ext_q4_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80, 80>;
...
@@ -5310,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
...
@@ -5310,6 +5455,7 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80, 80>;
...
@@ -5321,6 +5467,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
...
@@ -5321,6 +5467,7 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80, 80>;
...
@@ -5332,6 +5479,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
...
@@ -5332,6 +5479,7 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80, 80>;
...
@@ -5343,6 +5491,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
...
@@ -5343,6 +5491,7 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_at
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk40_dv40" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 40, 40>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk64_dv64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_dk80_dv80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80, 80>;
...
@@ -5878,77 +6027,103 @@ kernel void kernel_flash_attn_ext_vec(
...
@@ -5878,77 +6027,103 @@ kernel void kernel_flash_attn_ext_vec(
float, float4, \
float, float4, \
float4
float4
#define FA_TYPES_F32 \
half4, \
float4, \
float4, \
float, \
float, float4, \
float4
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 32, 32, 4>;
#if defined(GGML_METAL_HAS_BF16)
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 32, 32, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk32_dv32")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 32, 32, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 2>;
#if defined(GGML_METAL_HAS_BF16)
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk64_dv64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 64, 64, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 96, 96, 4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 96, 96, 4>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk96_dv96")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 96, 96, 4>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 128, 128, 1>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 128, 128, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 192, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 192, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 192, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 192, 128, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 192, 128, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk192_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 192, 128, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 256, 256, 1>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 256, 256, 1>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>;
template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 576, 512, 2>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 576, 512, 2>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 576, 512, 2>;
#undef FA_TYPES
...
ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt
cmake_minimum_required(VERSION 3.19)
cmake_policy(SET CMP0114 NEW)
cmake_policy(SET CMP0116 NEW)
if (POLICY CMP0147)
    # Parallel build custom build steps
    cmake_policy(SET CMP0147 NEW)
endif()
find_package(Vulkan COMPONENTS glslc REQUIRED)
if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
    # Parallel build object files
    add_definitions(/MP)
endif()
function(detect_host_compiler)
    if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
        find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH)
...
ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp
...
@@ -97,8 +97,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
#define GGML_VK_MAX_NODES 8192
#define MAX_VK_BUFFERS 256
#define VK_CHECK(err, msg) \
do { \
vk::Result err_ = (err); \
...
@@ -387,6 +385,14 @@ enum shader_reduction_mode {
static constexpr uint32_t num_argsort_pipelines = 11;
static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
static constexpr uint32_t num_topk_moe_pipelines = 10;
static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_DIV, GGML_OP_RESHAPE };
static constexpr std::array topk_moe { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
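These two arrays spell out the node sequences the fusion pass looks for: softmax -> reshape -> argsort -> view -> get_rows, optionally followed by the reshape/sum_rows/div/reshape chain that renormalizes the selected expert weights. The real matching lives in ggml_vk_can_fuse_topk_moe further down and also validates sources and use counts; as a rough, simplified sketch (the helper name ops_match is invented here purely for illustration), the op-pattern part of that check amounts to:

    // Minimal sketch, not the actual implementation: compare only the op kinds
    // of pattern.size() consecutive graph nodes starting at idx.
    template <size_t N>
    static bool ops_match(const ggml_cgraph * graph, int idx,
                          const std::array<ggml_op, N> & pattern) {
        if (idx + (int)N > graph->n_nodes) {
            return false;
        }
        for (size_t i = 0; i < N; ++i) {
            if (graph->nodes[idx + i]->op != pattern[i]) {
                return false;
            }
        }
        return true;
    }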
struct vk_device_struct {
std::recursive_mutex mutex;
...
@@ -584,6 +590,9 @@ struct vk_device_struct {
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
vk_pipeline pipeline_opt_step_adamw_f32;
vk_pipeline pipeline_opt_step_sgd_f32;
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
...
@@ -597,6 +606,9 @@ struct vk_device_struct {
vk_pipeline pipeline_flash_attn_split_k_reduce;
// [2] is {!norm, norm}
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
std::vector<vk_pipeline_ref> all_pipelines;
std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
...
@@ -940,6 +952,11 @@ struct vk_op_multi_add_push_constants {
static_assert(MAX_PARAMETER_COUNT == 12);
static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
struct vk_op_topk_moe_push_constants {
uint32_t n_rows;
uint32_t n_expert_used;
};
struct vk_op_add_id_push_constants {
uint32_t ne0;
uint32_t ne1;
...
@@ -1089,6 +1106,19 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t C;
uint32_t H;
};
struct vk_op_ssm_scan_push_constants {
uint32_t nb02, nb03, nb12, nb13;
uint32_t nb21, nb22, nb31;
uint32_t nb42, nb43, nb52, nb53;
uint32_t s_off;
uint32_t n_head, d_head, n_group, n_tok;
};
struct vk_op_ssm_conv_push_constants {
uint32_t nb01, nb02;
uint32_t nb11;
uint32_t dst_nb0, dst_nb1, dst_nb2;
uint32_t nc, ncs, nr, n_t, n_s;
};
struct vk_op_conv2d_push_constants {
uint32_t Cout;
...
@@ -1281,7 +1311,6 @@ struct ggml_vk_garbage_collector {
std::vector<vk_semaphore> tl_semaphores;
std::vector<vk_semaphore> semaphores;
std::vector<vk::Event> events;
std::vector<vk_buffer> temp_buffers;
std::vector<vk_context> contexts;
};
...
@@ -1452,8 +1481,6 @@ struct ggml_backend_vk_context {
// and set to true after the buffer contents are consumed.
bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
vk_buffer buffer_pool[MAX_VK_BUFFERS];
vk_context_ref compute_ctx;
vk_context_ref transfer_ctx;
...
@@ -2651,11 +2678,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
} \
}
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
...
@@ -2663,6 +2692,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
...
@@ -3590,6 +3620,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
} else {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
}
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
...
@@ -3700,6 +3740,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
}
for (auto &c : compiles) {
c.wait();
}
...
@@ -4690,7 +4735,14 @@ static void ggml_vk_instance_init() {
vk::PhysicalDeviceIDProperties old_id;
old_props.pNext = &old_id;
devices[k].getProperties2(&old_props);
return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
equals = equals || (
old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
);
return equals;
}
);
if (old_device == vk_instance.device_indices.end()) {
...
@@ -4728,6 +4780,7 @@ static void ggml_vk_instance_init() {
#endif
break;
}
driver_priorities[vk::DriverId::eMesaDozen] = 100;
if (driver_priorities.count(old_driver.driverID)) {
old_priority = driver_priorities[old_driver.driverID];
...
@@ -5101,71 +5154,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
}
static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
VK_LOG_MEMORY("ggml_vk_pool_malloc");
int best_i = -1;
size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
int worst_i = -1;
size_t worst_size = 0; //largest unused buffer seen so far
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
vk_buffer &b = ctx->buffer_pool[i];
if (b != nullptr && b->size >= size && b->size < best_size) {
best_i = i;
best_size = b->size;
}
if (b != nullptr && b->size > worst_size) {
worst_i = i;
worst_size = b->size;
}
}
if(best_i != -1) {
//found the smallest buffer that fits our needs
vk_buffer b = ctx->buffer_pool[best_i];
ctx->buffer_pool[best_i].reset();
return b;
}
if(worst_i != -1) {
//no buffer that fits our needs, resize largest one to save memory
vk_buffer& b = ctx->buffer_pool[worst_i];
ggml_vk_destroy_buffer(b);
}
return ggml_vk_create_buffer_device(ctx->device, size);
}
static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
vk_buffer& b = ctx->buffer_pool[i];
if (b == nullptr) {
b = buffer;
return;
}
}
std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
ggml_vk_destroy_buffer(buffer);
}
// Returns an available temporary buffer that may only be used temporarily, it will be reused
static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
// Try to find existing temp buffer with enough capacity
for (auto& buffer : ctx->gc.temp_buffers) {
if (buffer->size >= size) {
return buffer;
}
}
VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");
// Otherwise create new buffer
vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
ctx->gc.temp_buffers.push_back(buf);
return buf;
}
static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
vk_buffer buf = ggml_vk_create_buffer(device, size,
...
@@ -7459,8 +7447,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
// For F32, the shader treats it as a block of size 4 (for vec4 loads)
if (k->type == GGML_TYPE_F32) {
k_stride /= 4;
}
if (v->type == GGML_TYPE_F32) {
v_stride /= 4;
}
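For F32 K/V the strides computed above start out in elements, but the shader consumes the data as vec4 blocks (BLOCK_SIZE 4 in flash_attn_base.glsl), so they are rescaled to vec4 units here. A quick illustration with made-up sizes:

    // e.g. a tightly packed K row of 128 float32 values:
    //   nbk1                    = 128 * 4 = 512 bytes
    //   k_stride in elements    = 512 / ggml_type_size(GGML_TYPE_F32) = 128
    //   k_stride in vec4 blocks = 128 / 4 = 32   <- value the shader expects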
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
bool aligned = (KV % alignment) == 0 &&
...
@@ -7974,6 +7970,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines);
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
return ctx->device->pipeline_topk_moe[idx][with_norm];
}
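The index into pipeline_topk_moe is simply the rounded-up log2 of the expert count, matching the 1u<<i specialization constant used when the pipelines were created above. Worked example (illustrative numbers):

    // dst->ne[0] (n_experts) = 8 -> idx = (uint32_t)ceilf(log2f(8.0f)) = 3
    //   -> pipeline_topk_moe[3][with_norm], compiled with spec constant 1u << 3 = 8
    // (ggml_vk_can_fuse_topk_moe only admits power-of-two expert counts,
    //  so the ceil is exact in practice)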
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
}
...
@@ -8089,6 +8092,21 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_SSM_SCAN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t d_state = src0->ne[0];
if (d_state == 128) {
return ctx->device->pipeline_ssm_scan_f32_d128;
} else if (d_state == 256) {
return ctx->device->pipeline_ssm_scan_f32_d256;
}
}
return nullptr;
case GGML_OP_SSM_CONV:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_ssm_conv_f32;
}
return nullptr;
case GGML_OP_OPT_STEP_ADAMW:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_opt_step_adamw_f32;
...
@@ -8583,6 +8601,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
}
}
break;
case GGML_OP_SSM_CONV:
{
const uint32_t nr = src0->ne[1];
const uint32_t n_t = dst->ne[1];
const uint32_t n_s = dst->ne[2];
elements = { nr, n_t, n_s };
}
break;
default:
elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
break;
...
@@ -9029,6 +9055,117 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
);
}
static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const ggml_tensor * src3 = dst->src[3];
const ggml_tensor * src4 = dst->src[4];
const ggml_tensor * src5 = dst->src[5];
GGML_ASSERT(dst->buffer != nullptr);
const uint32_t head_dim = src0->ne[1];
const uint32_t n_head = src1->ne[1];
const uint32_t n_group = src4->ne[1];
const uint32_t n_tok = src1->ne[2];
const uint32_t n_seq = src1->ne[3];
bool is_mamba2 = (src3->nb[1] == sizeof(float));
GGML_ASSERT(is_mamba2);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, dst->op);
GGML_ASSERT(pipeline != nullptr);
if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}
const int64_t s_off = ggml_nelements(src1) * sizeof(float);
const vk_op_ssm_scan_push_constants pc = {
(uint32_t)src0->nb[2], (uint32_t)src0->nb[3],
(uint32_t)src1->nb[2], (uint32_t)src1->nb[3],
(uint32_t)src2->nb[1], (uint32_t)src2->nb[2],
(uint32_t)src3->nb[1],
(uint32_t)src4->nb[2], (uint32_t)src4->nb[3],
(uint32_t)src5->nb[2], (uint32_t)src5->nb[3],
(uint32_t)s_off,
n_head, head_dim, n_group, n_tok
};
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
ggml_backend_vk_buffer_context * src_buf_ctxs[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
}
vk_buffer d_D = nullptr, d_srcs[GGML_MAX_SRC] = { nullptr };
size_t dst_offset = 0, src_offsets[GGML_MAX_SRC] = { 0 };
bool dst_uma = false, srcs_uma[GGML_MAX_SRC] = { false };
if (ctx->device->uma) {
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
srcs_uma[i] = d_srcs[i] != nullptr;
}
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
dst_uma = d_D != nullptr;
}
if (!dst_uma) {
d_D = dst_buf_ctx->dev_buffer;
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
}
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
if (!srcs_uma[i]) {
d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
}
}
size_t dst_size = ggml_nbytes(dst);
size_t src_sizes[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC && dst->src[i] != nullptr; i++) {
src_sizes[i] = ggml_nbytes(dst->src[i]);
}
std::array<uint32_t, 3> elements;
const int splitH = 16;
const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH);
const uint32_t num_workgroups_y = n_seq;
elements = { num_workgroups_x, num_workgroups_y, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
vk_subbuffer{ d_D, dst_offset, dst_size }
}, pc, elements);
}
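A detail worth calling out: s_off = ggml_nelements(src1) * sizeof(float) is the byte offset at which the updated recurrent state is written inside the destination buffer, after the per-token outputs — the layout ggml uses for SSM_SCAN results in its other backends as well. With made-up sizes:

    // e.g. src1 (x) with n_head = 16, head_dim = 64, n_tok = 32, n_seq = 1:
    //   ggml_nelements(src1) = 16 * 64 * 32 * 1 = 32768
    //   s_off                = 32768 * sizeof(float) = 131072 bytes
    // dst[0 .. s_off)   : per-token outputs y
    // dst[s_off .. end) : updated SSM state for the next ubatch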
static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SSM_CONV, {
(uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
(uint32_t)src1->nb[1],
(uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
(uint32_t)src1->ne[0],
(uint32_t)src0->ne[0],
(uint32_t)src0->ne[1],
(uint32_t)dst->ne[1],
(uint32_t)dst->ne[2],
}, dryrun);
}
static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
const ggml_tensor * x = dst->src[0];
const ggml_tensor * g = dst->src[1];
...
@@ -9425,6 +9562,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
}
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
ggml_tensor * ids = cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
const int n_experts = logits->ne[0];
const int n_rows = logits->ne[1];
const int n_expert_used = weights->ne[1];
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
if (dryrun) {
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
return;
}
ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
vk_buffer d_logits = nullptr;
size_t logits_buf_offset = 0;
vk_buffer d_weights = nullptr;
size_t weights_buf_offset = 0;
vk_buffer d_ids = nullptr;
size_t ids_buf_offset = 0;
bool logits_uma = false;
bool weights_uma = false;
bool ids_uma = false;
if (ctx->device->uma) {
ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
logits_uma = d_logits != nullptr;
weights_uma = d_weights != nullptr;
ids_uma = d_ids != nullptr;
}
if (!logits_uma) {
d_logits = logits_buf_ctx->dev_buffer;
logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
GGML_ASSERT(d_logits != nullptr);
}
if (!weights_uma) {
d_weights = weights_buf_ctx->dev_buffer;
weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
GGML_ASSERT(d_weights != nullptr);
}
if (!ids_uma) {
d_ids = ids_buf_ctx->dev_buffer;
ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
GGML_ASSERT(d_ids != nullptr);
}
vk_op_topk_moe_push_constants pc;
pc.n_rows = n_rows;
pc.n_expert_used = n_expert_used;
GGML_ASSERT(n_expert_used <= n_experts);
const uint32_t rows_per_block = 4;
std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
{
ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
}, pc, elements);
}
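Each workgroup of the fused kernel handles rows_per_block = 4 rows of logits, so the dispatch is one-dimensional over ceil(n_rows / 4) groups, with the shader presumably bounding itself by pc.n_rows for the tail group. For example:

    // n_rows = 10 token rows -> CEIL_DIV(10, 4) = 3 workgroups in x,
    // with the last group covering only 2 valid rows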
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
...
@@ -10861,6 +11079,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_FLASH_ATTN_EXT:
case GGML_OP_OPT_STEP_ADAMW:
...
@@ -11008,11 +11228,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
ctx->unsynced_nodes_read.clear();
ggml_vk_sync_buffers(ctx, compute_ctx);
}
// Add the last fused node and all fused source nodes to the unsynchronized list.
// Add all fused nodes to the unsynchronized lists.
const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
ctx->unsynced_nodes_written.push_back(last_node);
for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
// Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
ctx->unsynced_nodes_written.push_back(cur_node);
for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
if (!cur_node->src[j]) {
continue;
...
@@ -11179,7 +11399,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_SOFT_MAX:
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
if (ctx->num_additional_fused_ops) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
} else {
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
}
break;
case GGML_OP_SOFT_MAX_BACK:
...
@@ -11278,6 +11502,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break;
case GGML_OP_SSM_SCAN:
ggml_vk_ssm_scan(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_SSM_CONV:
ggml_vk_ssm_conv(ctx, compute_ctx, node, dryrun);
break;
case GGML_OP_OPT_STEP_ADAMW:
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
...
@@ -11389,6 +11623,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_CONV_2D_DW:
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
case GGML_OP_LEAKY_RELU:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
...
@@ -11498,10 +11734,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
// Clean up after graph processing is done
static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
for (auto& buffer : ctx->gc.temp_buffers) {
ggml_vk_pool_free(ctx, buffer);
}
ctx->gc.temp_buffers.clear();
ctx->prealloc_y_last_pipeline_used = {};
ctx->unsynced_nodes_written.clear();
...
@@ -11544,10 +11776,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
ggml_vk_destroy_buffer(ctx->prealloc_split_k);
ctx->prealloc_y_last_pipeline_used = nullptr;
for (auto& buffer : ctx->buffer_pool) {
ggml_vk_destroy_buffer(buffer);
}
ctx->prealloc_size_x = 0;
ctx->prealloc_size_y = 0;
ctx->prealloc_size_split_k = 0;
...
@@ -11988,6 +12216,120 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
return true;
}
static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
int node_idx, bool with_norm) {
if (with_norm) {
if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
return false;
}
}
} else {
if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
return false;
}
for (size_t i = 0; i < topk_moe.size(); ++i) {
if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
return false;
}
}
}
const ggml_tensor * softmax = cgraph->nodes[node_idx + 0];
const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
const float * op_params = (const float *)softmax->op_params;
float scale = op_params[0];
float max_bias = op_params[1];
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
return false;
}
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
const int n_expert = softmax->ne[0];
// n_expert must be a power of 2
if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
return false;
}
// Check that the nodes don't have any unexpected uses
const ggml_tensor * reshape1 = cgraph->nodes[node_idx + 1];
const ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
const ggml_tensor * view = cgraph->nodes[node_idx + 3];
const ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
const ggml_tensor * reshape5 = with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
const ggml_tensor * sum_rows = with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
const ggml_tensor * div = with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
const ggml_tensor * reshape8 = with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
// softmax is used by reshape and argsort
if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
reshape1->src[0] != softmax ||
argsort->src[0] != softmax) {
return false;
}
// reshape is used by get_rows
if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
get_rows->src[0] != reshape1) {
return false;
}
// argsort is used by view
if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
view->src[0] != argsort) {
return false;
}
// view is written (via argsort), we can skip checking it
if (with_norm) {
// get_rows is used by reshape
if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
reshape5->src[0] != get_rows) {
return false;
}
// reshape is used by sum_rows and div
if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
sum_rows->src[0] != reshape5 ||
div->src[0] != reshape5) {
return false;
}
// sum_rows is used by div
if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
div->src[1] != sum_rows) {
return false;
}
// div/reshape are written
if (reshape8->src[0] != div) {
return false;
}
}
if (!ctx->device->subgroup_arithmetic ||
!ctx->device->subgroup_shuffle ||
!ctx->device->subgroup_require_full_support ||
ctx->device->disable_fusion) {
return false;
}
return true;
}
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
const ggml_tensor *first_node = cgraph->nodes[node_idx];
...
@@ -12063,6 +12405,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
ctx->num_additional_fused_ops = topk_moe.size() - 1;
}
}
}
ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
...
@@ -12160,6 +12506,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = num_adds - 1;
} else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ctx->num_additional_fused_ops = 1;
ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
} else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
ctx->num_additional_fused_ops = topk_moe.size() - 1;
}
}
}
...
@@ -12167,10 +12517,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
bool submit = (submitted_nodes >= nodes_per_submit) ||
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
=
= last_node) ||
(i + ctx->num_additional_fused_ops
>
= last_node) ||
(almost_ready && !ctx->almost_ready_fence_pending);
(almost_ready && !ctx->almost_ready_fence_pending);
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops
=
= last_node, almost_ready, submit);
bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops
>
= last_node, almost_ready, submit);
if (vk_perf_logger_enabled) {
if (vk_perf_logger_enabled) {
if (ctx->compute_ctx.expired()) {
@@ -12292,6 +12642,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
...
@@ -12292,6 +12642,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
while (first_unused < graph->n_nodes) {
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
std::vector<int> current_set;
// Avoid reordering topk_moe_norm
if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
bool is_topk_moe_norm = true;
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
is_topk_moe_norm = false;
}
}
if (is_topk_moe_norm) {
for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
new_order.push_back(graph->nodes[first_unused + j]);
used[first_unused + j] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
continue;
}
}
// First, grab the next unused node.
current_set.push_back(first_unused);
...
@@ -12797,6 +13166,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
switch (op->src[1]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
// supported in scalar and coopmat2 paths
...
@@ -13004,6 +13374,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_SSM_SCAN:
{
for (int i = 0; i < 6; i++) {
if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
return false;
}
}
if (op->src[6] && op->src[6]->type != GGML_TYPE_I32) {
return false;
}
if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) {
return false;
}
const uint32_t d_state = op->src[0]->ne[0];
const uint32_t head_dim = op->src[0]->ne[1];
bool is_mamba2 = (op->src[3] && op->src[3]->nb[1] == sizeof(float));
if (!is_mamba2) {
return false;
}
if ((d_state != 128 && d_state != 256) || head_dim % 16 != 0) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
const uint32_t SPLIT_H = 16;
size_t stateC_size = SPLIT_H * d_state * sizeof(float);
if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) {
return false;
}
return true;
}
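The shared-memory guard at the end is easy to verify by hand; with the shapes admitted above the staged state tile needs at most:

    // SPLIT_H * d_state * sizeof(float) = 16 * 256 * 4 = 16384 bytes (16 KiB)
    // d_state == 128 halves that to 8 KiB

so SSM_SCAN is only refused on devices that report less compute shared memory than that tile requires.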
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
...
@@ -13386,14 +13797,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
struct ggml_context * ggml_ctx = ggml_init(iparams);
std::array<struct ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<struct ggml_tensor *, GGML_MAX_SRC> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0};
std::array<size_t, GGML_MAX_SRC> src_size = {};
std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
std::array<void *, GGML_MAX_SRC> src_buffer = {};
const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"};
const char * srci_name[GGML_MAX_SRC] = {"src0", "src1", "src2", "src3", "src4", "src5", "src6", "src7", "src8", "src9"};
struct ggml_tensor * tensor_clone = nullptr;
for (int i = 0; i < 6; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
ggml_tensor * srci = tensor->src[i];
if (fused_rms_norm_mul) {
rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
...
@@ -13700,6 +14111,11 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
src_clone[2]);
} else if (tensor->op == GGML_OP_ADD_ID) {
tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]);
} else if (tensor->op == GGML_OP_SSM_SCAN) {
tensor_clone = ggml_ssm_scan(ggml_ctx, src_clone[0], src_clone[1], src_clone[2],
src_clone[3], src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_SSM_CONV) {
tensor_clone = ggml_ssm_conv(ggml_ctx, src_clone[0], src_clone[1]);
}
else {
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
...
@@ -13721,7 +14137,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
memcpy(comp_result, tensor_clone->data, comp_size);
memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS);
for (int i = 0; i < 6; i++) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (src_buffer[i] != nullptr) {
free(src_buffer[i]);
}
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl
#include "types.glsl"
#include "types.glsl"
layout
(
buffer_reference
,
std430
,
buffer_reference_align
=
16
)
buffer
decodeBufF32
{
vec4
block
;
};
float16_t
dequantFuncF32
(
const
in
decodeBufF32
bl
,
const
in
uint
blockCoords
[
2
],
const
in
uint
coordInBlock
[
2
])
{
const
vec4
v
=
bl
.
block
;
const
uint
idx
=
coordInBlock
[
1
];
const
f16vec4
vf16
=
f16vec4
(
v
);
return
vf16
[
idx
];
}
layout
(
buffer_reference
,
std430
,
buffer_reference_align
=
2
)
buffer
decodeBufQ4_0
{
layout
(
buffer_reference
,
std430
,
buffer_reference_align
=
2
)
buffer
decodeBufQ4_0
{
block_q4_0_packed16
block
;
block_q4_0_packed16
block
;
};
};
...
@@ -717,4 +729,6 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords
#define dequantFuncA dequantFuncIQ4_NL
#elif defined(DATA_A_MXFP4)
#define dequantFuncA dequantFuncMXFP4
#elif defined(DATA_A_F32)
#define dequantFuncA dequantFuncF32
#endif
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
...
@@ -345,7 +345,7 @@ void main() {
float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
...
@@ -64,13 +64,31 @@ layout (binding = 4) readonly buffer S {float data_s[];};
layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
#if defined(A_TYPE_PACKED16)
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
#if defined(DATA_A_F32)
layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed;
#elif defined(A_TYPE_PACKED16)
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
#if defined(DATA_A_F32)
#undef BLOCK_SIZE
#define BLOCK_SIZE 4
#define BLOCK_BYTE_SIZE 16
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    // iqs is currently always zero in the flash attention shaders
    if (binding_idx == BINDING_IDX_K) {
        return k_packed.k_data_packed[a_offset + ib];
    } else {
        return v_packed.v_data_packed[a_offset + ib];
    }
}
#endif
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
...
@@ -380,7 +380,7 @@ void main() {
float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = 1.0 / Lf[r];
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
}
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
...
@@ -121,7 +121,11 @@ void main() {
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
#if defined(ACC_TYPE_MAX)
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
#else
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
#endif
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
...
@@ -294,7 +298,7 @@ void main() {
[[unroll]]
for (int k = 0; k < Ldiag.length(); ++k) {
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
}
O = Ldiag*O;
...
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
...
@@ -91,7 +91,7 @@ void main() {
         L = L*ms + vs;
     }

-    L = 1.0 / L;
+    L = (L == 0.0) ? 0.0 : 1.0 / L;

     // D dimension is split across workgroups in the y dimension
     uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
...
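The same pattern recurs across the flash-attention hunks above: the softmax row sum can legitimately be zero (for example when every position in a row is masked out), so its reciprocal is guarded instead of computed unconditionally. A minimal C++ illustration of that guard, with an illustrative function name that is not part of the shaders:

#include <cstdio>

// Zero-guarded reciprocal: a zero row sum yields 0 rather than +inf,
// so the later multiply O * (1/L) cannot turn into NaN.
static float safe_rcp(float l) {
    return (l == 0.0f) ? 0.0f : 1.0f / l;
}

int main() {
    printf("%f %f\n", safe_rcp(4.0f), safe_rcp(0.0f)); // 0.250000 0.000000
    return 0;
}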
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
...
@@ -313,12 +313,12 @@ void main() {
         sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f);
     }
 #else
-    ACC_TYPE sums[WMITER * TM * WNITER * TN];
+    ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2];
     FLOAT_TYPE_VEC2 cache_a[WMITER * TM];
-    FLOAT_TYPE_VEC2 cache_b[TN];
+    FLOAT_TYPE_VEC2 cache_b;

-    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = ACC_TYPE(0.0f);
+    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
+        sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f);
     }
 #endif
...
@@ -360,20 +360,22 @@ void main() {
                 cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i];
             }
         }

         [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-            [[unroll]] for (uint j = 0; j < TN; j++) {
-                cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i];
-            }
-            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                    [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                        const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr;
-                        sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx]));
-                    }
-                }
-            }
+            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+                cache_b = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + cc) * SHMEM_STRIDE + i];
+                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                    [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+                        // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr]
+                        const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
+                        sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr    ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x));
+                        sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y));
+                    }
+                }
+            }
         }
 #endif
...
@@ -388,8 +390,9 @@ void main() {
         }
     }
 #else
-    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) {
+        sums[i].x = clamp(sums[i].x, -ACC_TYPE_MAX, ACC_TYPE_MAX);
+        sums[i].y = clamp(sums[i].y, -ACC_TYPE_MAX, ACC_TYPE_MAX);
     }
 #endif
 #endif
...
@@ -463,14 +466,21 @@ void main() {
             const u16vec2 row_idx = row_ids[row_i - ic * BN];
 #endif // MUL_MAT_ID
-            [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+            [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+                const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr;
 #ifdef MUL_MAT_ID
-                if (dr_warp + cr < p.M) {
-                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                if (dr_warp + 2 * cr < p.M) {
+                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
                 }
+                if (dr_warp + 2 * cr + 1 < p.M) {
+                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
+                }
 #else
-                if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                    data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) {
+                    data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x);
                 }
+                if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) {
+                    data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y);
+                }
 #endif // MUL_MAT_ID
             }
...
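For intuition on the mul_mm.comp change above: each accumulator now packs two adjacent output rows into one ACC_TYPE_VEC2, so a single loop iteration issues two FMA chains against the same cached B value and the accumulator array shrinks by half. A minimal C++ sketch of that packing idea, with illustrative names and sizes that are not taken from the shader:

#include <array>
#include <cmath>
#include <cstdio>

struct Vec2 { float x = 0.f, y = 0.f; };

int main() {
    constexpr int TM = 4;                    // rows handled per thread (illustrative)
    std::array<float, TM> a = {1, 2, 3, 4};  // cached column of A
    const float b = 0.5f;                    // single cached B value

    // Scalar form: one accumulator per row.
    std::array<float, TM> scalar{};
    for (int cr = 0; cr < TM; ++cr)
        scalar[cr] = std::fma(a[cr], b, scalar[cr]);

    // Packed form: one Vec2 accumulator covers rows 2*cr and 2*cr+1,
    // reusing the same b for both FMAs and halving the accumulator count.
    std::array<Vec2, TM / 2> packed{};
    for (int cr = 0; cr < TM / 2; ++cr) {
        packed[cr].x = std::fma(a[2 * cr],     b, packed[cr].x);
        packed[cr].y = std::fma(a[2 * cr + 1], b, packed[cr].y);
    }

    printf("%f %f | %f %f\n", scalar[0], scalar[1], packed[0].x, packed[0].y);
    return 0;
}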
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp
new file mode 100644
#version 450

#extension GL_EXT_control_flow_attributes : require

#include "types.glsl"

layout(constant_id = 0) const uint BLOCK_SIZE = 32;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
layout(binding = 2) buffer Dst { float dst[]; };

layout(push_constant) uniform PushConstants {
    uint nb01; uint nb02;
    uint nb11;
    uint dst_nb0; uint dst_nb1; uint dst_nb2;
    uint nc; uint ncs; uint nr; uint n_t; uint n_s;
};

void main() {
    const uint global_thread_id = gl_GlobalInvocationID.x;
    const uint i2 = gl_WorkGroupID.y;
    const uint i3 = gl_WorkGroupID.z;

    if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
        return;
    }

    const uint i1 = global_thread_id;
    const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
    const uint src1_base = i1 * (nb11 / 4);
    const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;

    float sum = 0.0;
    [[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
        const uint src0_idx = src0_base + i0;
        const uint src1_idx = src1_base + i0;
        sum += src0[src0_idx] * src1[src1_idx];
    }
    dst[dst_idx] = sum;
}
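Read per output element, the new ssm_conv shader reduces to a dot product of the nc filter taps for row i1 against a window of that row starting at token index i2 (the strides suggest a padded input row). A small single-threaded C++ reference of that inner reduction, with an illustrative function name and flattened indexing not taken from ggml:

#include <cstdio>
#include <vector>

// One output element: dst(i3, i2, i1) = sum_{i0 < nc} src0_row[i2 + i0] * src1_row[i0],
// where src0_row is the (padded) input row for channel i1 and src1_row its nc taps.
float ssm_conv_element(const std::vector<float>& src0_row,
                       const std::vector<float>& src1_row,
                       int i2, int nc) {
    float sum = 0.0f;
    for (int i0 = 0; i0 < nc; ++i0) {
        sum += src0_row[i2 + i0] * src1_row[i0];
    }
    return sum;
}

int main() {
    std::vector<float> row  = {1, 2, 3, 4, 5, 6};
    std::vector<float> taps = {0.5f, 0.25f, 0.25f};
    printf("%f\n", ssm_conv_element(row, taps, 1, 3)); // 0.5*2 + 0.25*3 + 0.25*4 = 2.75
    return 0;
}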
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp
new file mode 100644
#version 450

#extension GL_EXT_control_flow_attributes : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif

#include "types.glsl"

layout(constant_id = 0) const uint D_STATE = 128;
layout(constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout(constant_id = 2) const uint SPLIT_H = 16;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) readonly buffer Src0 { float s0[]; };
layout(binding = 1) readonly buffer Src1 { float x[]; };
layout(binding = 2) readonly buffer Src2 { float dt[]; };
layout(binding = 3) readonly buffer Src3 { float A[]; };
layout(binding = 4) readonly buffer Src4 { float B[]; };
layout(binding = 5) readonly buffer Src5 { float C[]; };
layout(binding = 6) readonly buffer Src6 { int ids[]; };
layout(binding = 7) buffer Dst { float d[]; };

layout(push_constant) uniform PushConstants {
    uint nb02; uint nb03; uint nb12; uint nb13;
    uint nb21; uint nb22; uint nb31;
    uint nb42; uint nb43; uint nb52; uint nb53;
    uint s_off;
    uint n_head;
    uint d_head;
    uint n_group;
    uint n_tok;
};

float softplus(float x) {
    if (x <= 20.0) {
        return log(1.0 + exp(x));
    } else {
        return x;
    }
}

shared float stateC[SPLIT_H * D_STATE];

void main() {
    const uint tid = gl_LocalInvocationID.x;
    const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head;
    const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4;
    const uint seq_idx = gl_WorkGroupID.y;

    const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4;

    const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;
    const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4;
    const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4;
    const uint A_base_idx = (head_idx * nb31) / 4;
    const uint B_base_idx = (seq_idx * nb43 + group_off) / 4;
    const uint C_base_idx = (seq_idx * nb53 + group_off) / 4;
    const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H;

    const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4;

    const uint stride_x = nb12 / 4;
    const uint stride_dt = nb21 / 4;
    const uint stride_B = nb42 / 4;
    const uint stride_C = nb52 / 4;
    const uint stride_y = n_head * d_head;

    float state[SPLIT_H];
    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
        state[j] = s0[s0_base_idx + j * D_STATE + tid];
    }

    for (uint i = 0; i < n_tok; i++) {
        const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]);
        const float dA = exp(dt_soft_plus * A[A_base_idx]);
        const float B_val = B[B_base_idx + i * stride_B + tid];
        const float C_val = C[C_base_idx + i * stride_C + tid];

        [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
            const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus;
            state[j] = (state[j] * dA) + (B_val * x_dt);
            stateC[j * D_STATE + tid] = state[j] * C_val;
        }

        barrier();

        [[unroll]]
        for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
            [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
                const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
                if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
                    stateC[k] += stateC[k + w];
                }
            }
            barrier();
        }

        [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
            const uint idx = (tid % SUBGROUP_SIZE) +
                D_STATE * (tid / SUBGROUP_SIZE) +
                j * D_STATE * (D_STATE / SUBGROUP_SIZE);
            const uint max_idx = SUBGROUP_SIZE - 1 +
                D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
                j * D_STATE * (D_STATE / SUBGROUP_SIZE);

            if (idx < SPLIT_H * D_STATE ||
                max_idx < SPLIT_H * D_STATE) {
                float sc;
#if USE_SUBGROUP_ADD
                sc = stateC[idx];
                sc = subgroupAdd(sc);
#else
                [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
                    if (idx + offset < SPLIT_H * D_STATE) {
                        stateC[idx] += stateC[idx + offset];
                    }
                    barrier();
                }
                if (tid % SUBGROUP_SIZE == 0) {
                    sc = stateC[idx];
                }
#endif
                if (tid % SUBGROUP_SIZE == 0) {
                    const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
                    d[y_base_idx + i * stride_y + k] = sc;
                }
            }
        }

        barrier();
    }

    [[unroll]] for (uint j = 0; j < SPLIT_H; j++) {
        d[s_base_idx + j * D_STATE + tid] = state[j];
    }
}
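For orientation, the update that the time-step loop above applies to each head channel j and state index s can be written out as follows; this is inferred directly from the shader body (with the shader's softplus short-circuiting to x for x > 20), not from separate documentation:

\begin{aligned}
\Delta_t   &= \operatorname{softplus}(dt_t) = \log\!\bigl(1 + e^{dt_t}\bigr) \\
h_t[j,s]   &= h_{t-1}[j,s]\,\exp(\Delta_t A) \;+\; B_t[s]\,\Delta_t\,x_t[j] \\
y_t[j]     &= \sum_{s=0}^{\mathrm{D\_STATE}-1} h_t[j,s]\,C_t[s]
\end{aligned}

The shared-memory reductions in the loop implement the sum over s, either with subgroupAdd when USE_SUBGROUP_ADD is set or with a manual tree reduction otherwise.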
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
new file mode 100644
#version 450

#extension GL_EXT_control_flow_attributes : require
#extension GL_KHR_shader_subgroup_basic : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_shuffle : enable

#include "types.glsl"

layout (push_constant) uniform parameter
{
    uint n_rows;
    uint n_expert_used;
};

layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;

layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts = 512;
layout(constant_id = 2) const bool with_norm = true;

const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};

void main() {
    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
    if (row >= n_rows) {
        return;
    }

    const uint logits_offset = n_experts * row;
    const uint weights_offset = n_expert_used * row;
    const uint ids_offset = n_experts * row;

    float logits_r[experts_per_thread];

    const float INFINITY = 1.0 / 0.0;

    [[unroll]]
    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
        const uint expert = i + gl_LocalInvocationID.x;
        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
    }

    float max_val = logits_r[0];

    [[unroll]]
    for (int i = 1; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        max_val = max(val, max_val);
    }

    max_val = subgroupMax(max_val);

    float wt[experts_per_thread];
    float tmp = 0.f;

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        wt[i] = exp(val - max_val);
        tmp += wt[i];
    }

    tmp = subgroupAdd(tmp);

    const float inv_sum = 1.0f / tmp;

    [[unroll]]
    for (int i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_sum;
    }

    // at this point, each thread holds a portion of softmax,
    // we do the argmax reduce over n_expert_used, each time marking
    // the expert weight as -inf to exclude from the next iteration

    float wt_sum = 0.f;

    float output_weights[experts_per_thread];

    for (int k = 0; k < n_expert_used; k++) {
        float max_val = wt[0];
        uint max_expert = gl_LocalInvocationID.x;

        [[unroll]]
        for (int i = 1; i < experts_per_thread; i++) {
            const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
                max_val = wt[i];
                max_expert = expert;
            }
        }

        [[unroll]]
        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
            const float val = subgroupShuffleXor(max_val, mask);
            const uint expert = subgroupShuffleXor(max_expert, mask);
            if (val > max_val || (val == max_val && expert < max_expert)) {
                max_val = val;
                max_expert = expert;
            }
        }

        if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
            output_weights[k / WARP_SIZE] = max_val;
        }

        if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

            ids[ids_offset + k] = max_expert;
            if (with_norm) {
                wt_sum += max_val;
            }
        }
    }

    if (with_norm) {
        wt_sum = subgroupAdd(wt_sum);
        const float inv_sum = 1.0f / wt_sum;

        [[unroll]]
        for (uint i = 0; i < experts_per_thread; ++i) {
            output_weights[i] *= inv_sum;
        }
    }

    [[unroll]]
    for (uint i = 0; i < experts_per_thread; ++i) {
        uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
        if (idx < n_expert_used) {
            weights[weights_offset + idx] = output_weights[i];
        }
    }
}
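The comment inside the shader describes the selection strategy: after a softmax over the expert logits, the top n_expert_used experts are found by repeated argmax, masking each winner to -inf so it cannot win again, followed by an optional renormalization of the selected weights. A small single-threaded C++ sketch of that strategy (no subgroup parallelism; names are illustrative, not taken from ggml):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

// Softmax over the logits, then k rounds of argmax with -inf masking,
// mirroring the shader's control flow sequentially.
void topk_moe_ref(const std::vector<float>& logits, int k, bool with_norm,
                  std::vector<float>& weights, std::vector<int>& ids) {
    const float inf = std::numeric_limits<float>::infinity();

    float max_val = -inf;
    for (float v : logits) max_val = std::max(max_val, v);

    std::vector<float> wt(logits.size());
    float sum = 0.f;
    for (size_t i = 0; i < logits.size(); ++i) { wt[i] = std::exp(logits[i] - max_val); sum += wt[i]; }
    for (float& w : wt) w /= sum;

    float wt_sum = 0.f;
    for (int j = 0; j < k; ++j) {
        int best = 0;
        for (size_t i = 1; i < wt.size(); ++i) if (wt[i] > wt[best]) best = (int)i;
        ids.push_back(best);
        weights.push_back(wt[best]);
        wt_sum += wt[best];
        wt[best] = -inf;                 // exclude this expert from the next round
    }
    if (with_norm)
        for (float& w : weights) w /= wt_sum;
}

int main() {
    std::vector<float> logits = {0.1f, 2.0f, -1.0f, 1.5f};
    std::vector<float> w; std::vector<int> ids;
    topk_moe_ref(logits, 2, true, w, ids);
    printf("ids: %d %d  weights: %f %f\n", ids[0], ids[1], w[0], w[1]);
    return 0;
}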
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
...
@@ -611,9 +611,6 @@ void process_shaders() {
     }

     for (const auto& tname : type_names) {
-        if (tname == "f32") {
-            continue;
-        }
         if (tname == "bf16") continue;
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
...
@@ -630,7 +627,7 @@ void process_shaders() {
         if (tname == "f16") {
             string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
                 merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
-        } else if (tname == "q4_0" || tname == "q8_0") {
+        } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
             std::string data_a_key = "DATA_A_" + to_uppercase(tname);
             string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
                 merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_" + to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
...
@@ -639,7 +636,7 @@ void process_shaders() {
         if (tname == "f16") {
             string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
                 merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
-        } else if (tname == "q4_0" || tname == "q8_0") {
+        } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
             std::string data_a_key = "DATA_A_" + to_uppercase(tname);
             string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
                 merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_" + to_uppercase(tname)}}), true, false, false, f16acc);
...
@@ -919,6 +916,13 @@ void process_shaders() {
     string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS", "0"}});
     string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS", "1"}});

+    string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
+    string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
+    string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
+    string_to_spv("topk_moe_f32", "topk_moe.comp", {});
+
     for (auto &c : compiles) {
         c.wait();
     }
...
@@ -962,7 +966,7 @@ void write_output_files() {
     }

     std::string suffixes[2] = {"_f32", "_f16"};
-    for (auto op : {"add", "sub", "mul", "div", "add_rms"}) {
+    for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) {
         hdr << "extern const void * " << op << "_data[2][2][2][2];\n";
         hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n";
...