OpenDAS / ollama — commit 0cefd46f (unverified)

llama: update to commit de4c07f93 (#10655)

Authored May 12, 2025 by Jeffrey Morgan; committed via GitHub on May 12, 2025.
Parent: ad035ad5
Changes: 113. Showing 13 changed files with 1466 additions and 740 deletions (+1466, −740).
Changed files shown:

ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu   (+6 −6)
ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu  (+6 −6)
ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu   (+6 −6)
ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu   (+6 −6)
ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu   (+6 −6)
ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu   (+6 −6)
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal  (+385 −181)
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h       (+36 −11)
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m            (+263 −126)
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal        (+349 −170)
ml/backend/ggml/ggml/src/ggml-opt.cpp                       (+368 −190)
ml/backend/ggml/ggml/src/ggml-quants.c                      (+0 −6)
ml/backend/ggml/ggml/src/ggml.c                             (+29 −20)
ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
@@ -2,9 +2,9 @@
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64,  4, 8);
-DECL_FATTN_MMA_F16_CASE(80,  4, 8);
-DECL_FATTN_MMA_F16_CASE(96,  4, 8);
-DECL_FATTN_MMA_F16_CASE(112, 4, 8);
-DECL_FATTN_MMA_F16_CASE(128, 4, 8);
-DECL_FATTN_MMA_F16_CASE(256, 4, 8);
+DECL_FATTN_MMA_F16_CASE(64,  64,  4, 8);
+DECL_FATTN_MMA_F16_CASE(80,  80,  4, 8);
+DECL_FATTN_MMA_F16_CASE(96,  96,  4, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
@@ -2,9 +2,9 @@
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64,  64, 1);
-DECL_FATTN_MMA_F16_CASE(80,  64, 1);
-DECL_FATTN_MMA_F16_CASE(96,  64, 1);
-DECL_FATTN_MMA_F16_CASE(112, 64, 1);
-DECL_FATTN_MMA_F16_CASE(128, 64, 1);
-DECL_FATTN_MMA_F16_CASE(256, 64, 1);
+DECL_FATTN_MMA_F16_CASE(64,  64,  64, 1);
+DECL_FATTN_MMA_F16_CASE(80,  80,  64, 1);
+DECL_FATTN_MMA_F16_CASE(96,  96,  64, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 64, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 64, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 64, 1);
ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
@@ -2,9 +2,9 @@
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64,  8, 1);
-DECL_FATTN_MMA_F16_CASE(80,  8, 1);
-DECL_FATTN_MMA_F16_CASE(96,  8, 1);
-DECL_FATTN_MMA_F16_CASE(112, 8, 1);
-DECL_FATTN_MMA_F16_CASE(128, 8, 1);
-DECL_FATTN_MMA_F16_CASE(256, 8, 1);
+DECL_FATTN_MMA_F16_CASE(64,  64,  8, 1);
+DECL_FATTN_MMA_F16_CASE(80,  80,  8, 1);
+DECL_FATTN_MMA_F16_CASE(96,  96,  8, 1);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 1);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 1);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 1);
ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
@@ -2,9 +2,9 @@
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64,  8, 2);
-DECL_FATTN_MMA_F16_CASE(80,  8, 2);
-DECL_FATTN_MMA_F16_CASE(96,  8, 2);
-DECL_FATTN_MMA_F16_CASE(112, 8, 2);
-DECL_FATTN_MMA_F16_CASE(128, 8, 2);
-DECL_FATTN_MMA_F16_CASE(256, 8, 2);
+DECL_FATTN_MMA_F16_CASE(64,  64,  8, 2);
+DECL_FATTN_MMA_F16_CASE(80,  80,  8, 2);
+DECL_FATTN_MMA_F16_CASE(96,  96,  8, 2);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 2);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 2);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 2);
ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
@@ -2,9 +2,9 @@
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64,  8, 4);
-DECL_FATTN_MMA_F16_CASE(80,  8, 4);
-DECL_FATTN_MMA_F16_CASE(96,  8, 4);
-DECL_FATTN_MMA_F16_CASE(112, 8, 4);
-DECL_FATTN_MMA_F16_CASE(128, 8, 4);
-DECL_FATTN_MMA_F16_CASE(256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(64,  64,  8, 4);
+DECL_FATTN_MMA_F16_CASE(80,  80,  8, 4);
+DECL_FATTN_MMA_F16_CASE(96,  96,  8, 4);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
@@ -2,9 +2,9 @@
 #include "../fattn-mma-f16.cuh"
 
-DECL_FATTN_MMA_F16_CASE(64,  8, 8);
-DECL_FATTN_MMA_F16_CASE(80,  8, 8);
-DECL_FATTN_MMA_F16_CASE(96,  8, 8);
-DECL_FATTN_MMA_F16_CASE(112, 8, 8);
-DECL_FATTN_MMA_F16_CASE(128, 8, 8);
-DECL_FATTN_MMA_F16_CASE(256, 8, 8);
+DECL_FATTN_MMA_F16_CASE(64,  64,  8, 8);
+DECL_FATTN_MMA_F16_CASE(80,  80,  8, 8);
+DECL_FATTN_MMA_F16_CASE(96,  96,  8, 8);
+DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
+DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
+DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
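All six template-instance files make the same mechanical change: DECL_FATTN_MMA_F16_CASE gains a second head-size argument, so the K/Q head size (DKQ) and the V head size (DV) can be instantiated independently instead of being assumed equal. As a hedged sketch of why the files look this way (the stand-in types and function below are illustrative, not the actual ggml source), the macro fans explicit template instantiations out across separate .cu files so each size combination compiles as its own translation unit:

    // Hypothetical stand-ins so the sketch compiles on its own.
    struct cuda_ctx;  // stands in for ggml_backend_cuda_context
    struct tensor;    // stands in for ggml_tensor

    template <int DKQ, int DV, int ncols1, int ncols2>
    void fattn_mma_f16_case(cuda_ctx & ctx, tensor * dst) {
        // ... launch the flash-attention kernel specialized for these sizes ...
    }

    // One explicit instantiation per (DKQ, DV, ncols1, ncols2) combination;
    // before this commit the macro took a single head size for both K/Q and V.
    #define DECL_FATTN_MMA_F16_CASE(DKQ, DV, ncols1, ncols2) \
        template void fattn_mma_f16_case<DKQ, DV, ncols1, ncols2>(cuda_ctx &, tensor *)

    DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); // square case, as in these files
    DECL_FATTN_MMA_F16_CASE(576, 512, 4, 8); // rectangular case the split enables (e.g. MLA-style attention)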
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal
@@ -2071,6 +2071,10 @@ typedef struct {
     float    attn_factor;
     float    beta_fast;
     float    beta_slow;
+    int32_t  sect_0;
+    int32_t  sect_1;
+    int32_t  sect_2;
+    int32_t  sect_3;
 } ggml_metal_kargs_rope;
 
 typedef struct {
@@ -2163,21 +2167,42 @@ typedef struct {
 } ggml_metal_kargs_mul_mv_ext;
 
 typedef struct {
-    int32_t  nei0;
-    int32_t  nei1;
-    uint64_t nbi1;
+    int32_t  ne10;
+    int32_t  ne11;  // n_expert_used (bcast)
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  neh11; // n_tokens
+    uint64_t nbh11;
+    int32_t  ne20;  // n_expert_used
+    uint64_t nb21;
+} ggml_metal_kargs_mul_mm_id_map0;
+
+typedef struct {
+    int32_t  ne20; // n_expert_used
+    int32_t  neh0;
+    int32_t  neh1;
+    uint64_t nbh1;
+    uint64_t nbh2;
+    int32_t  ne0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_mul_mm_id_map1;
+
+typedef struct {
     int32_t  ne00;
     int32_t  ne02;
    uint64_t nb01;
     uint64_t nb02;
-    int32_t  ne11;
-    int32_t  ne12;
-    int32_t  ne13;
-    uint64_t nb10;
-    uint64_t nb11;
-    uint64_t nb12;
-    int32_t  ne0;
-    int32_t  ne1;
+    uint64_t nb03;
+    int32_t  neh12;
+    uint64_t nbh10;
+    uint64_t nbh11;
+    uint64_t nbh12;
+    uint64_t nbh13;
+    int32_t  neh0;
+    int32_t  neh1;
+    int16_t  r2;
+    int16_t  r3;
 } ggml_metal_kargs_mul_mm_id;
 
 typedef struct {
@@ -5166,8 +5191,148 @@ kernel void kernel_rope_neox(
         }
     }
 }
 
+template<typename T>
+kernel void kernel_rope_multi(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < args.n_dims) {
+            const int ic = i0/2;
+
+            // mrope theta calculations
+            // note: the rest is the same as kernel_rope_neox
+            const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
+            const int sec_w01   = args.sect_0 + args.sect_1;               // end of section 1
+            const int sec_w012  = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
+            const int sector    = ic % sect_dims;
+
+            float theta_base;
+            if (sector < args.sect_0) {
+                theta_base = (float) pos[i2];
+            } else if (sector < sec_w01) {
+                theta_base = (float) pos[i2 + args.ne02];
+            } else if (sector < sec_w012) {
+                theta_base = (float) pos[i2 + args.ne02 * 2];
+            } else {
+                theta_base = (float) pos[i2 + args.ne02 * 3];
+            }
+            // end of mrope
+
+            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
+
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[args.n_dims/2];
+
+            dst_data[0]             = x0*cos_theta - x1*sin_theta;
+            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
+template<typename T>
+kernel void kernel_rope_vision(
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    float corr_dims[2];
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
+
+    device const int32_t * pos = (device const int32_t *) src1;
+
+    const float inv_ndims = -1.f/args.n_dims;
+
+    float cos_theta;
+    float sin_theta;
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
+            const int ic = i0/2;
+
+            // mrope theta calculations (only support 2 dimensions)
+            const int sect_dims = args.sect_0 + args.sect_1;
+            const int sector    = ic % sect_dims;
+
+            float p;
+            float theta_base;
+            if (sector < args.sect_1) {
+                p = (float) sector;
+                theta_base = (float) pos[i2];
+            } else {
+                p = (float) sector - args.sect_0;
+                theta_base = (float) pos[i2 + args.ne02];
+            }
+
+            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
+            // end of mrope
+
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
+
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
+
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[args.n_dims]; // different from kernel_rope_multi
+
+            dst_data[0]           = x0*cos_theta - x1*sin_theta;
+            dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
+        } else {
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
 typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
 typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
+typedef decltype(kernel_rope_multi<float>)  kernel_rope_multi_t;
+typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
 
 template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
 template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
@@ -5175,6 +5340,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
 template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
 template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
 
+template [[host_name("kernel_rope_multi_f32")]]  kernel kernel_rope_multi_t  kernel_rope_multi<float>;
+template [[host_name("kernel_rope_multi_f16")]]  kernel kernel_rope_multi_t  kernel_rope_multi<half>;
+template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
+template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
+
 typedef void (im2col_t)(
         device const float * x,
         device       char  * dst,
@@ -8834,127 +9005,219 @@ kernel void kernel_mul_mm(
     }
 }
 
-// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
-// TODO: this kernel needs to be reimplemented from scratch for better performance
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-void kernel_mul_mm_id_impl(
-        int32_t  ne00,
-        int32_t  ne02,
-        uint64_t nb01,
-        uint64_t nb02,
-        int32_t  ne11,
-        int32_t  ne12,
-        uint64_t nb10,
-        uint64_t nb11,
-        uint64_t nb12,
-        int32_t  ne0,
-        int32_t  ne1,
-        int64_t  ne0ne1,
-        device const char * src0,
-        device const char * src1,
-        threadgroup ushort2 * rowids,
-        device       char * dst,
-        threadgroup char  * shmem,
+template<typename T4>
+kernel void kernel_mul_mm_id_map0(
+        constant ggml_metal_kargs_mul_mm_id_map0 & args,
+        device const char * src1,
+        device const char * src2,
+        device       char * hsrc1,
+        device       char * htpe,
+        device       char * hids,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int ide = tgpig[0]; // expert id
+
+    int n_all = 0;
+
+    device int32_t * ids_i32 = (device int32_t *) (hids);
+
+    for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
+        device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
+
+        for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
+            if (src2_i32[i20] != ide) {
+                continue;
+            }
+
+            device const float4 * src1_f32x4  = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
+            device       T4     * hsrc1_f32x4 = (device       T4     *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
+
+            for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
+                hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
+            }
+
+            if (tpitg.x == 0) {
+                ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
+            }
+
+            ++n_all;
+        }
+    }
+
+    if (tpitg.x == 0) {
+        device int32_t * tpe_i32 = (device int32_t *) (htpe);
+        tpe_i32[ide] = n_all;
+    }
+}
+
+typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
+
+template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
+
+template<typename T>
+kernel void kernel_mul_mm_id_map1(
+        constant ggml_metal_kargs_mul_mm_id_map1 & args,
+        device const char * hdst,
+        device const char * hids,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i20 = tgpig[0]; // used expert
+    const int i21 = tgpig[1]; // token
+
+    device const int32_t * ids_i32   = (device const int32_t *) (hids);
+    device       float4  * dst_f32x4 = (device       float4  *) (dst + i20*args.nb1 + i21*args.nb2);
+
+    const int id = ids_i32[i21*args.ne20 + i20];
+
+    const int ide = id / args.neh1;
+    const int idt = id % args.neh1;
+
+    device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
+
+    for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
+        dst_f32x4[i0] = hdst_f32x4[i0];
+    }
+}
+
+typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
+
+template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
+
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_mul_mm_id(
+        constant ggml_metal_kargs_mul_mm_id & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * tpe,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiitg[[thread_index_in_threadgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    threadgroup half  * sa = (threadgroup half  *)(shmem);
-    threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+    threadgroup T    * sa = (threadgroup T    *)(shmem);
+    threadgroup half * sb = (threadgroup half *)(shmem + 4096);
 
     const int r0 = tgpig.y;
     const int r1 = tgpig.x;
     const int im = tgpig.z;
 
-    if (r1*BLOCK_SIZE_N >= ne1) return;
+    device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
+
+    const int neh1 = tpe_i32[im];
+
+    if (r1*BLOCK_SIZE_N >= neh1) {
+        return;
+    }
 
     // if this block is of 64x32 shape or smaller
-    short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    const short n_cols = (     neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (     neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
 
     // a thread shouldn't load data outside of the matrix
-    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
-    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+    const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
-    simdgroup_half8x8  ma[4];
-    simdgroup_float8x8 mb[2];
+    simdgroup_T8x8    ma[4];
+    simdgroup_half8x8 mb[2];
+
     simdgroup_float8x8 mc[8];
 
-    for (int i = 0; i < 8; i++){
+    for (short i = 0; i < 8; i++){
         mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
     }
 
     short il = (tiitg % THREAD_PER_ROW);
 
-    ushort offset1 = il/nl;
+    const int i12 = im%args.neh12;
+    const int i13 = im/args.neh12;
 
-    threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
+    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const short    offset1 = il/nl;
 
-    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
-    device const float   * y = (device const float   *)(src1
-        + nb12 * id[1]
-        + nb11 * (id[0] % ne11)
-        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
+    device const half    * y = (device const half    *)(src1
+        + args.nbh13*i13
+        + args.nbh12*i12
+        + args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
+        + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
-    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
         // load data and store to threadgroup memory
-        half4x4 temp_a;
+        T4x4 temp_a;
         dequantize_func(x, il, temp_a);
+
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        for (int i = 0; i < 16; i++) {
-            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
-            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+        #pragma unroll(16)
+        for (short i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+                    + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+                    + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
         }
 
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+        *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
-        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
         y += BLOCK_SIZE_K;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // load matrices from threadgroup memory and conduct outer products
-        threadgroup half  * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
-        threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
+        threadgroup const T    * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
+        threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
 
         #pragma unroll(BLOCK_SIZE_K/8)
-        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+        for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
             #pragma unroll(4)
-            for (int i = 0; i < 4; i++) {
+            for (short i = 0; i < 4; i++) {
                 simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
             }
 
             simdgroup_barrier(mem_flags::mem_none);
 
             #pragma unroll(2)
-            for (int i = 0; i < 2; i++) {
+            for (short i = 0; i < 2; i++) {
                 simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
             }
 
-            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
-            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
-
             #pragma unroll(8)
-            for (int i = 0; i < 8; i++){
+            for (short i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
             }
+
+            lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
+            lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
         }
     }
 
-    {
+    if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
+        device float * C = (device float *) dst +
+            (BLOCK_SIZE_M * r0 + 32*(sgitg &  1)) + \
+            (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
+
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
+        }
+    } else {
+        // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
         threadgroup float * temp_str = ((threadgroup float *) shmem) \
-                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
-        for (int i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+                                     + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
+
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         if (sgitg == 0) {
             for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
-                int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
-
-                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + joff;
+                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
                 device float4 * D4 = (device float4 *) D;
 
                 threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
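The new scheme replaces the per-expert rowids scan inside the matmul kernel with a two-pass pipeline: kernel_mul_mm_id_map0 gathers each expert's token rows into a contiguous staging buffer (hsrc1) and records the row mapping (hids) plus a tokens-per-expert count (htpe); the dense matmul then runs over the compacted rows; and kernel_mul_mm_id_map1 scatters the results back to their original (expert, token) slots. A CPU sketch of the same gather/scatter bookkeeping, with hypothetical names, assuming one staged row per (token, used-expert) pair and a per-expert capacity equal to n_tokens (as the kernels above use):

    #include <cstdint>
    #include <vector>

    // Illustrative re-statement of the map0/map1 index bookkeeping (not ggml code).
    // ids[t*n_used + k] = expert chosen for token t in slot k.
    struct moe_map {
        std::vector<int32_t> tpe; // tokens per expert            (plays the role of htpe)
        std::vector<int32_t> map; // (token, slot) -> staged row  (plays the role of hids)
    };

    moe_map build_map(const std::vector<int32_t> & ids,
                      int n_tokens, int n_used, int n_expert, int cap_per_expert) {
        moe_map m;
        m.tpe.assign(n_expert, 0);
        m.map.assign((size_t) n_tokens * n_used, -1);

        // pass 0 ("map0"): for each expert, walk all (token, slot) pairs and assign
        // the next free row in that expert's contiguous region of the staging buffer
        for (int e = 0; e < n_expert; ++e) {
            int n_all = 0;
            for (int t = 0; t < n_tokens; ++t) {
                for (int k = 0; k < n_used; ++k) {
                    if (ids[(size_t) t*n_used + k] != e) continue;
                    m.map[(size_t) t*n_used + k] = e*cap_per_expert + n_all; // staged row id
                    ++n_all;
                }
            }
            m.tpe[e] = n_all;
        }
        return m;
    }

    // pass 1 ("map1") would read m.map[(t, k)], split it back into
    // (expert = id / cap_per_expert, row = id % cap_per_expert), and copy that
    // staged output row into dst[(t, k)] -- exactly what kernel_mul_mm_id_map1 does.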
@@ -8974,66 +9237,6 @@ void kernel_mul_mm_id_impl(
     }
 }
 
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm_id(
-        constant ggml_metal_kargs_mul_mm_id & args,
-        device const char * src0s,
-        device const char * src1,
-        device       char * dst,
-        device const char * ids,
-        threadgroup  char * shmem [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiitg[[thread_index_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int32_t i02 = tgpig.z;
-
-    tgpig.z = 0;
-
-    device const char * src0 = src0s + i02*args.nb02;
-
-    // row indices
-    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
-
-    // TODO: parallelize this loop
-    int32_t _ne1 = 0;
-    for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
-        for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
-            int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
-            if (id == i02) {
-                if (tiitg == 0) {
-                    rowids[_ne1] = ushort2(ii0, ii1);
-                }
-                _ne1++;
-            }
-        }
-    }
-
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
-        args.ne00,
-        args.ne02,
-        args.nb01,
-        args.nb02,
-        args.ne11,
-        args.ne12,
-        args.nb10,
-        args.nb11,
-        args.nb12,
-        args.ne0,
-        _ne1,
-        (int64_t)args.ne0*args.ne1,
-        src0,
-        src1,
-        rowids,
-        dst,
-        shmem,
-        tgpig,
-        tiitg,
-        sgitg);
-}
-
 #define QK_NL 16
 
 //
@@ -9074,63 +9277,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
 //
 // matrix-matrix multiplication
 //
 
-typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;
 
-template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_f32_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
+template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
 #endif
-template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
 
 //
 // indirect matrix-matrix multiplication
 //
 
-typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
+typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;
 
-template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
+template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
 #endif
-template [[host_name("kernel_mul_mm_id_q4_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_id_q4_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_id_q5_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_id_q5_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_id_q8_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_id_q2_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_id_q3_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_id_q4_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_id_q5_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_id_q6_K_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_id_iq3_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_id_iq2_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_id_iq1_s_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_id_iq1_m_f32")]]   kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_mul_mm_id_q4_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_id_q4_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_id_q5_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_id_q5_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_id_q8_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_q2_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_id_q3_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_id_q4_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_id_q5_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_id_q6_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_id_iq3_s_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_id_iq2_s_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_id_iq1_s_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_id_iq1_m_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
 
 //
 // matrix-vector multiplication
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -207,6 +207,10 @@ typedef struct {
     float    attn_factor;
     float    beta_fast;
     float    beta_slow;
+    int32_t  sect_0;
+    int32_t  sect_1;
+    int32_t  sect_2;
+    int32_t  sect_3;
 } ggml_metal_kargs_rope;
 
 typedef struct {
@@ -299,21 +303,42 @@ typedef struct {
 } ggml_metal_kargs_mul_mv_ext;
 
 typedef struct {
-    int32_t  nei0;
-    int32_t  nei1;
-    uint64_t nbi1;
+    int32_t  ne10;
+    int32_t  ne11;  // n_expert_used (bcast)
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  neh11; // n_tokens
+    uint64_t nbh11;
+    int32_t  ne20;  // n_expert_used
+    uint64_t nb21;
+} ggml_metal_kargs_mul_mm_id_map0;
+
+typedef struct {
+    int32_t  ne20; // n_expert_used
+    int32_t  neh0;
+    int32_t  neh1;
+    uint64_t nbh1;
+    uint64_t nbh2;
+    int32_t  ne0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_mul_mm_id_map1;
+
+typedef struct {
     int32_t  ne00;
     int32_t  ne02;
     uint64_t nb01;
     uint64_t nb02;
-    int32_t  ne11;
-    int32_t  ne12;
-    int32_t  ne13;
-    uint64_t nb10;
-    uint64_t nb11;
-    uint64_t nb12;
-    int32_t  ne0;
-    int32_t  ne1;
+    uint64_t nb03;
+    int32_t  neh12;
+    uint64_t nbh10;
+    uint64_t nbh11;
+    uint64_t nbh12;
+    uint64_t nbh13;
+    int32_t  neh0;
+    int32_t  neh1;
+    int16_t  r2;
+    int16_t  r3;
 } ggml_metal_kargs_mul_mm_id;
 
 typedef struct {
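These kargs structs are copied byte-for-byte from the Objective-C host code into constant address space on the GPU (via setBytes), so the host- and shader-side definitions must agree exactly in field order, width, and padding. A small illustrative check one could write on the host side (hypothetical helper, not part of the ggml sources; the size/offset values follow ordinary C alignment rules and would need verifying against the target compiler):

    #include <cstdint>
    #include <cstddef>

    // Host-side mirror of ggml_metal_kargs_mul_mm_id_map0 as defined above.
    typedef struct {
        int32_t  ne10;
        int32_t  ne11;  // n_expert_used (bcast)
        uint64_t nb11;
        uint64_t nb12;
        int32_t  neh11; // n_tokens
        uint64_t nbh11;
        int32_t  ne20;  // n_expert_used
        uint64_t nb21;
    } kargs_mul_mm_id_map0;

    // The uint64_t members force 8-byte alignment, so each lone int32_t field is
    // padded out to the next 8-byte boundary; Metal lays out the constant buffer
    // the same way, which is what keeps [encoder setBytes:&args ...] in sync.
    static_assert(sizeof(kargs_mul_mm_id_map0) == 56, "unexpected kargs size");
    static_assert(offsetof(kargs_mul_mm_id_map0, nbh11) == 32, "unexpected padding");

    int main() { return 0; }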
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.m
@@ -306,30 +306,36 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
-    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,
     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
+    GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
@@ -651,7 +657,8 @@ static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
     }
 
     if (mem_pool->heaps_to_remove.count > 0) {
-        for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
+        // remove in reverse order
+        for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
             NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
             ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
@@ -660,6 +667,10 @@ static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
             [mem_pool->heaps removeObjectAtIndex:index];
 
             [ptr release];
+
+            if (i == 0) {
+                break;
+            }
         }
 
         [mem_pool->heaps_to_remove removeAllObjects];
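The loop rewrite fixes a classic pitfall: removing elements from an NSMutableArray by ascending index shifts everything after the removal point, so the indices recorded in heaps_to_remove can go stale (or out of range) mid-loop, while walking them from highest to lowest leaves the lower indices valid. And because NSUInteger is unsigned, the loop cannot use i >= 0 as its condition (it is always true), hence the explicit if (i == 0) break;. A hedged sketch of the same pattern on a plain vector:

    #include <cstddef>
    #include <vector>

    // Remove the elements of v at the positions listed in idx, where each position
    // was recorded against the original array and idx is sorted ascending. Walking
    // idx from the back keeps the not-yet-processed (smaller) positions valid.
    void remove_indices(std::vector<int> & v, const std::vector<size_t> & idx) {
        if (idx.empty()) {
            return;
        }
        for (size_t i = idx.size() - 1; ; --i) {
            v.erase(v.begin() + idx[i]);
            if (i == 0) {
                break; // size_t is unsigned: "i >= 0" would loop forever
            }
        }
    }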
@@ -673,7 +684,7 @@ static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
 }
 
 static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
-    const size_t alignment = 32;
+    const size_t alignment = 256;
 
     const size_t size_aligned = GGML_PAD(size, alignment);
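Raising the pool alignment from 32 to 256 bytes gives the allocator headroom for the stricter alignment some Metal matrix data types require (the comment later in this commit cites Table 2.5 of the Metal Shading Language Specification). GGML_PAD rounds a size up to the next multiple of the alignment; assuming the usual power-of-two definition from ggml.h, the arithmetic works out like this:

    #include <cassert>
    #include <cstddef>

    // Round x up to the next multiple of n, where n is a power of two.
    // This matches the conventional GGML_PAD definition in ggml.h.
    #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

    int main() {
        assert(GGML_PAD(1,   256) == 256);
        assert(GGML_PAD(256, 256) == 256);
        assert(GGML_PAD(257, 256) == 512);
        // with the old alignment the same request padded far less aggressively:
        assert(GGML_PAD(257, 32) == 288);
        return 0;
    }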
@@ -1243,30 +1254,36 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,   mul_mm_iq1_m_f32,   has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,  mul_mm_iq4_nl_f32,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,  mul_mm_iq4_xs_f32,  has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,     mul_mm_id_f32_f32,     has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,     mul_mm_id_f16_f32,     has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,    mul_mm_id_bf16_f32,    has_simdgroup_mm && use_bfloat);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,    mul_mm_id_q4_0_f32,    has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,    mul_mm_id_q4_1_f32,    has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,    mul_mm_id_q5_0_f32,    has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,    mul_mm_id_q5_1_f32,    has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,    mul_mm_id_q8_0_f32,    has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,    mul_mm_id_q2_K_f32,    has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,    mul_mm_id_q3_K_f32,    has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,    mul_mm_id_q4_K_f32,    has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,    mul_mm_id_q5_K_f32,    has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,    mul_mm_id_q6_K_f32,    has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,  mul_mm_id_iq2_xs_f32,  has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,   mul_mm_id_iq3_s_f32,   has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,   mul_mm_id_iq2_s_f32,   has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,   mul_mm_id_iq1_s_f32,   has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,   mul_mm_id_iq1_m_f32,   has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,  mul_mm_id_iq4_nl_f32,  has_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,  mul_mm_id_iq4_xs_f32,  has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,    mul_mm_id_map0_f16,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,    mul_mm_id_map1_f32,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,     mul_mm_id_f32_f16,     has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,     mul_mm_id_f16_f16,     has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,    mul_mm_id_bf16_f16,    has_simdgroup_mm && use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16,    mul_mm_id_q4_0_f16,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16,    mul_mm_id_q4_1_f16,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,    mul_mm_id_q5_0_f16,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,    mul_mm_id_q5_1_f16,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,    mul_mm_id_q8_0_f16,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,    mul_mm_id_q2_K_f16,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,    mul_mm_id_q3_K_f16,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,    mul_mm_id_q4_K_f16,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16,    mul_mm_id_q5_K_f16,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16,    mul_mm_id_q6_K_f16,    has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16, mul_mm_id_iq2_xxs_f16, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16,  mul_mm_id_iq2_xs_f16,  has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16, mul_mm_id_iq3_xxs_f16, has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16,   mul_mm_id_iq3_s_f16,   has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16,   mul_mm_id_iq2_s_f16,   has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16,   mul_mm_id_iq1_s_f16,   has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16,   mul_mm_id_iq1_m_f16,   has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16,  mul_mm_id_iq4_nl_f16,  has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,  mul_mm_id_iq4_xs_f16,  has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,   rope_norm_f32,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,   rope_norm_f16,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,  rope_multi_f32,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,  rope_multi_f16,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,   rope_neox_f32,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,   rope_neox_f16,   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,      im2col_f16,      true);
@@ -1630,16 +1647,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ROPE:
-            {
-                const int mode = ((const int32_t *) op->op_params)[2];
-                if (mode & GGML_ROPE_TYPE_MROPE) {
-                    return false;
-                }
-                if (mode & GGML_ROPE_TYPE_VISION) {
-                    return false;
-                }
-                return true;
-            }
+            return true;
         case GGML_OP_IM2COL:
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
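Before this commit the Metal backend reported GGML_OP_ROPE unsupported whenever the mode word carried the mrope or vision bits, forcing a CPU fallback; with kernel_rope_multi and kernel_rope_vision in place the check collapses to return true. For reference, the removed check was a simple bit test on op_params; a hedged sketch with the constants inlined from my reading of ggml.h (treat the numeric values as an assumption and prefer including ggml.h directly):

    #include <cstdint>

    // Assumed values: GGML_ROPE_TYPE_MROPE = 8, GGML_ROPE_TYPE_VISION = 24 (= 8 | 16).
    enum {
        GGML_ROPE_TYPE_MROPE  = 8,
        GGML_ROPE_TYPE_VISION = 24,
    };

    static bool rope_needs_multi_kernel(const int32_t * op_params) {
        const int mode = op_params[2]; // the rope mode word lives at op_params[2]
        // VISION shares MROPE's bit, so one test covers both variants.
        return (mode & GGML_ROPE_TYPE_MROPE) != 0;
    }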
@@ -3002,7 +3010,7 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
 
                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
             } else {
                 id<MTLComputePipelineState> pipeline = nil;
@@ -3222,8 +3230,6 @@ static bool ggml_metal_encode_node(
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
                 const int n_as = src0->ne[2];
 
                 // src2 = ids
-                const enum ggml_type src2t = src2->type;
-                GGML_UNUSED(src2t);
@@ -3237,24 +3243,21 @@ static bool ggml_metal_encode_node(
                 GGML_ASSERT(ne03 == 1);
                 GGML_ASSERT(ne13 == 1);
 
+                const uint32_t r2 = 1;
+                const uint32_t r3 = 1;
+
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
                 // ne20 = n_used_experts
-                // ne21 = n_rows
-                const int dst_rows = ne20*ne21;
-                const int dst_rows_min = n_as;
-                const int dst_rows_max = (device.maxThreadgroupMemoryLength/2 - 8192)/4;
-
-                // max size of the rowids array in the kernel shared buffer
-                //GGML_ASSERT(dst_rows <= dst_rows_max);
+                // ne21 = n_rows (batch size)
+                const int ne21_mm_id_min = 32;
 
                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                 if ([device supportsFamily:MTLGPUFamilyApple7] &&
                     ne00 % 32 == 0 && ne00 >= 64 &&
-                    //ne01 / ne02 >= 512 && // NOTE: this is based on Mixtral shapes, might need adjustments
-                    dst_rows > dst_rows_min &&
-                    dst_rows <= dst_rows_max) {
+                    (ne21 >= ne21_mm_id_min)) {
                     GGML_ASSERT(ne00 % 4 == 0);
 
                     // some Metal matrix data types require aligned pointers
                     // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
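The dispatch heuristic becomes much simpler: instead of bounding ne20*ne21 (total destination rows) by the shared-memory capacity of the old rowids array, the new code only requires a minimum batch size before taking the matrix-matrix path, since the row mapping no longer lives in threadgroup memory. A sketch of the new predicate as a free function (hypothetical helper mirroring the condition above; supports_apple7 stands for [device supportsFamily:MTLGPUFamilyApple7]):

    // Decide between the matrix-matrix (mm) and matrix-vector (mv) MUL_MAT_ID
    // paths, per the condition in ggml_metal_encode_node above.
    bool use_mul_mm_id(bool supports_apple7, int ne00, int ne21) {
        const int ne21_mm_id_min = 32; // minimum batch size (n_rows) for the mm path

        return supports_apple7 &&
               ne00 % 32 == 0 && ne00 >= 64 &&
               ne21 >= ne21_mm_id_min;
    }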
@@ -3265,62 +3268,169 @@ static bool ggml_metal_encode_node(
                         default: break;
                     }
 
+                    const int64_t neh10 = ne10; // n_embd
+                    const int64_t neh11 = ne21; // n_tokens
+                    const int64_t neh12 = ne02; // n_expert
+
+                    const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
+                    const uint64_t nbh11 = nbh10*neh10;
+                    const uint64_t nbh12 = nbh11*neh11;
+                    const uint64_t nbh13 = nbh12*neh12;
+
+                    const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
+                    id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
+                    if (!h_src1) {
+                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
+                        return false;
+                    }
+
+                    const int64_t neh0 = ne0;
+                    const int64_t neh1 = ne21;
+                    const int64_t neh2 = ne02;
+
+                    const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
+                    const uint64_t nbh1 = nbh0*neh0;
+                    const uint64_t nbh2 = nbh1*neh1;
+                  //const uint64_t nbh3 = nbh2*neh2;
+
+                    const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
+                    id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
+                    if (!h_dst) {
+                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
+                        return false;
+                    }
+
+                    // tokens per expert
+                    const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
+                    id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
+                    if (!h_tpe) {
+                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
+                        return false;
+                    }
+
+                    // id map
+                    // [n_expert_used, n_tokens]
+                    const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
+                    id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
+                    if (!h_ids) {
+                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
+                        return false;
+                    }
+
+                    {
+                        const int nth = MIN(1024, ne10/4);
+
+                        ggml_metal_kargs_mul_mm_id_map0 args = {
+                            ne10,
+                            ne11,  // n_expert_used (bcast)
+                            nb11,
+                            nb12,
+                            neh11, // n_tokens
+                            nbh11,
+                            ne20,  // n_expert_used
+                            nb21,
+                        };
+
+                        id<MTLComputePipelineState> pipeline = nil;
+
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                        [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                        [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                        [encoder setBuffer:h_src1  offset:0         atIndex:3];
+                        [encoder setBuffer:h_tpe   offset:0         atIndex:4];
+                        [encoder setBuffer:h_ids   offset:0         atIndex:5];
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                    }
+
+                    {
                     id<MTLComputePipelineState> pipeline = nil;
 
                     switch (src0->type) {
-                        case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32    ].pipeline; break;
-                        case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32    ].pipeline; break;
-                        case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32   ].pipeline; break;
-                        case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32   ].pipeline; break;
-                        case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32   ].pipeline; break;
-                        case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32   ].pipeline; break;
-                        case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32   ].pipeline; break;
-                        case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32   ].pipeline; break;
-                        case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32   ].pipeline; break;
-                        case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32   ].pipeline; break;
-                        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
-                        case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
-                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline; break;
-                        case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32  ].pipeline; break;
-                        case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline; break;
-                        case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32 ].pipeline; break;
+                        case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16    ].pipeline; break;
+                        case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16    ].pipeline; break;
+                        case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16   ].pipeline; break;
+                        case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F16   ].pipeline; break;
+                        case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F16   ].pipeline; break;
+                        case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16   ].pipeline; break;
+                        case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16   ].pipeline; break;
+                        case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16   ].pipeline; break;
+                        case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16   ].pipeline; break;
+                        case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16   ].pipeline; break;
+                        case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16   ].pipeline; break;
+                        case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F16   ].pipeline; break;
+                        case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F16   ].pipeline; break;
+                        case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F16].pipeline; break;
+                        case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F16 ].pipeline; break;
+                        case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F16].pipeline; break;
+                        case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F16  ].pipeline; break;
+                        case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F16  ].pipeline; break;
+                        case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F16  ].pipeline; break;
+                        case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16  ].pipeline; break;
+                        case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break;
+                        case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break;
                         default: GGML_ABORT("MUL_MAT_ID not implemented");
                     }
 
                     ggml_metal_kargs_mul_mm_id args = {
-                        /*.nei0 =*/ ne20,
-                        /*.nei1 =*/ ne21,
-                        /*.nbi1 =*/ nb21,
                         /*.ne00 =*/ ne00,
                         /*.ne02 =*/ ne02,
                         /*.nb01 =*/ nb01,
                         /*.nb02 =*/ nb02,
-                        /*.ne11 =*/ ne11,
-                        /*.ne12 =*/ ne12,
-                        /*.ne13 =*/ ne13,
-                        /*.nb10 =*/ nb10,
-                        /*.nb11 =*/ nb11,
-                        /*.nb12 =*/ nb12,
-                        /*.ne0 =*/ ne0,
-                        /*.ne1 =*/ ne1,
+                        /*.nb03 =*/ nb03,
+                        /*.neh12 =*/ neh12,
+                        /*.nbh10 =*/ nbh10,
+                        /*.nbh11 =*/ nbh11,
+                        /*.nbh12 =*/ nbh12,
+                        /*.nbh13 =*/ nbh13,
+                        /*.neh0 =*/
neh0
,
/*.neh1 =*/
neh1
,
/*.r2 =*/
r2
,
/*.r3 =*/
r3
,
};
[
encoder
setComputePipelineState
:
pipeline
];
[
encoder
setBytes
:
&
args
length
:
sizeof
(
args
)
atIndex
:
0
];
[
encoder
setBuffer
:
id_src0
offset
:
offs_src0
atIndex
:
1
];
[
encoder
setBuffer
:
id_src1
offset
:
offs_src1
atIndex
:
2
];
[
encoder
setBuffer
:
id_dst
offset
:
offs_dst
atIndex
:
3
];
[
encoder
setBuffer
:
id_src2
offset
:
offs_src2
atIndex
:
4
];
[
encoder
setBuffer
:
h_src1
offset
:
0
atIndex
:
2
];
[
encoder
setBuffer
:
h_tpe
offset
:
0
atIndex
:
3
];
[
encoder
setBuffer
:
h_dst
offset
:
0
atIndex
:
4
];
[
encoder
setThreadgroupMemoryLength
:
8192
atIndex
:
0
];
[
encoder
dispatchThreadgroups
:
MTLSizeMake
((
ne21
+
31
)
/
32
,
(
ne01
+
63
)
/
64
,
ne02
)
threadsPerThreadgroup
:
MTLSizeMake
(
128
,
1
,
1
)];
}
{
GGML_ASSERT
(
ne0
%
4
==
0
);
const
int
nth
=
MIN
(
1024
,
ne0
/
4
);
ggml_metal_kargs_mul_mm_id_map1
args
=
{
ne20
,
// n_expert_used
neh0
,
neh1
,
nbh1
,
nbh2
,
ne0
,
nb1
,
nb2
,
};
[
encoder
setThreadgroupMemoryLength
:
GGML_PAD
(
8192
+
dst_rows
*
4
/*sizeof(ushort2)*/
,
16
)
atIndex
:
0
];
id
<
MTLComputePipelineState
>
pipeline
=
nil
;
pipeline
=
ctx
->
kernels
[
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32
].
pipeline
;
[
encoder
setComputePipelineState
:
pipeline
];
[
encoder
setBytes
:
&
args
length
:
sizeof
(
args
)
atIndex
:
0
];
[
encoder
setBuffer
:
h_dst
offset
:
0
atIndex
:
1
];
[
encoder
setBuffer
:
h_ids
offset
:
0
atIndex
:
2
];
[
encoder
setBuffer
:
id_dst
offset
:
offs_dst
atIndex
:
3
];
[
encoder
dispatchThreadgroups
:
MTLSizeMake
((
ne21
+
31
)
/
32
,
(
ne01
+
63
)
/
64
,
n_as
)
threadsPerThreadgroup
:
MTLSizeMake
(
128
,
1
,
1
)];
[
encoder
dispatchThreadgroups
:
MTLSizeMake
(
ne20
,
ne21
,
1
)
threadsPerThreadgroup
:
MTLSizeMake
(
nth
,
1
,
1
)];
}
}
else
{
id
<
MTLComputePipelineState
>
pipeline
=
nil
;
...
...
@@ -3514,7 +3624,7 @@ static bool ggml_metal_encode_node(
                 [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];

                 const int64_t _ne1 = 1;
-                const int64_t ne123 = dst_rows;
+                const int64_t ne123 = ne20*ne21;

                 if (smem > 0) {
                     [encoder setThreadgroupMemoryLength:smem atIndex:0];
...
...
@@ -3718,6 +3828,7 @@ static bool ggml_metal_encode_node(
             } break;
         case GGML_OP_ROPE:
             {
                 // make sure we have one or more position id(ne10) per token(ne02)
                 GGML_ASSERT(ne10 % ne02 == 0);
                 GGML_ASSERT(ne10 >= ne02);
...
...
@@ -3745,19 +3856,41 @@ static bool ggml_metal_encode_node(
                 memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));

                 const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
+                const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
+                const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+                // mrope
+                const int sect_0 = ((const int32_t *) dst->op_params)[11];
+                const int sect_1 = ((const int32_t *) dst->op_params)[12];
+                const int sect_2 = ((const int32_t *) dst->op_params)[13];
+                const int sect_3 = ((const int32_t *) dst->op_params)[14];

                 id<MTLComputePipelineState> pipeline = nil;

-                if (!is_neox) {
+                if (is_neox) {
                     switch (src0->type) {
-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
                         default: GGML_ABORT("fatal error");
                     };
+                } else if (is_mrope && !is_vision) {
+                    GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
+                } else if (is_vision) {
+                    GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token
+                    switch (src0->type) {
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
+                        default: GGML_ABORT("fatal error");
+                    };
                 } else {
                     switch (src0->type) {
-                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
-                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
+                        case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
+                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
                         default: GGML_ABORT("fatal error");
                     };
                 }
...
...
@@ -3788,6 +3921,10 @@ static bool ggml_metal_encode_node(
                     /*.attn_factor =*/ attn_factor,
                     /*.beta_fast   =*/ beta_fast,
                     /*.beta_slow   =*/ beta_slow,
+                    /* sect_0      =*/ sect_0,
+                    /* sect_1      =*/ sect_1,
+                    /* sect_2      =*/ sect_2,
+                    /* sect_3      =*/ sect_3,
                 };

                 [encoder setComputePipelineState:pipeline];
...
...
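The host-side changes above split MUL_MAT_ID into three dispatches: a map0 pass that counts and compacts the src1 rows routed to each expert (filling h_tpe with tokens-per-expert and h_ids with a row map), one mul_mm_id matrix multiplication over the compacted rows, and a map1 pass that scatters the results back into the destination layout. A hedged CPU sketch of the map0 bookkeeping (illustrative names, not part of the commit):

#include <cstdint>
#include <vector>

// For each expert e, walk all (token, expert_used) pairs routed to e and
// assign them consecutive rows in the compacted per-expert buffer.
void map0(const std::vector<std::vector<int32_t>> & expert_ids, // [n_tokens][n_expert_used]
          int n_expert,
          std::vector<int32_t> & ids,  // [n_tokens*n_expert_used] -> compacted row index
          std::vector<int32_t> & tpe)  // [n_expert] -> number of rows used by the expert
{
    const int n_tokens      = (int) expert_ids.size();
    const int n_expert_used = (int) expert_ids[0].size();

    ids.assign(n_tokens*n_expert_used, -1);
    tpe.assign(n_expert, 0);

    for (int e = 0; e < n_expert; ++e) {
        int n_all = 0;
        for (int t = 0; t < n_tokens; ++t) {
            for (int u = 0; u < n_expert_used; ++u) {
                if (expert_ids[t][u] != e) {
                    continue;
                }
                ids[t*n_expert_used + u] = e*n_tokens + n_all; // row in the compacted buffer
                ++n_all;
            }
        }
        tpe[e] = n_all; // the mul_mm_id kernel reads this to early-out per expert
    }
}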
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
View file @ 0cefd46f
...
...
@@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox(
}
}
template<typename T>
kernel void kernel_rope_multi(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < args.n_dims) {
            const int ic = i0/2;

            // mrope theta calculations
            // note: the rest is the same as kernel_rope_neox
            const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3;
            const int sec_w01   = args.sect_0 + args.sect_1;               // end of section 1
            const int sec_w012  = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2
            const int sector    = ic % sect_dims;

            float theta_base;
            if (sector < args.sect_0) {
                theta_base = (float) pos[i2];
            } else if (sector < sec_w01) {
                theta_base = (float) pos[i2 + args.ne02];
            } else if (sector < sec_w012) {
                theta_base = (float) pos[i2 + args.ne02 * 2];
            } else {
                theta_base = (float) pos[i2 + args.ne02 * 3];
            }
            // end of mrope

            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);

            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

            const float x0 = src[0];
            const float x1 = src[args.n_dims/2];

            dst_data[0]             = x0*cos_theta - x1*sin_theta;
            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
        } else {
            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}

template<typename T>
kernel void kernel_rope_vision(
        constant ggml_metal_kargs_rope & args,
        device const char * src0,
        device const char * src1,
        device const char * src2,
        device       char * dst,
        ushort  tiitg[[thread_index_in_threadgroup]],
        ushort3 tptg [[threads_per_threadgroup]],
        uint3   tgpig[[threadgroup_position_in_grid]]) {
    const int i3 = tgpig[2];
    const int i2 = tgpig[1];
    const int i1 = tgpig[0];

    float corr_dims[2];
    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);

    device const int32_t * pos = (device const int32_t *) src1;

    const float inv_ndims = -1.f/args.n_dims;

    float cos_theta;
    float sin_theta;

    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
        if (i0 < 2*args.n_dims) { // different from kernel_rope_multi
            const int ic = i0/2;

            // mrope theta calculations (only support 2 dimensions)
            const int sect_dims = args.sect_0 + args.sect_1;
            const int sector    = ic % sect_dims;

            float p;
            float theta_base;
            if (sector < args.sect_1) {
                p = (float) sector;
                theta_base = (float) pos[i2];
            } else {
                p = (float) sector - args.sect_0;
                theta_base = (float) pos[i2 + args.ne02];
            }

            const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p);
            // end of mrope

            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;

            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);

            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);

            const float x0 = src[0];
            const float x1 = src[args.n_dims]; // different from kernel_rope_multi

            dst_data[0]           = x0*cos_theta - x1*sin_theta;
            dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi
        } else {
            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);

            dst_data[0] = src[0];
            dst_data[1] = src[1];
        }
    }
}
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
...
...
@@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
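For reference, kernel_rope_multi above assigns each rotated pair to one of four sections of the head dimension and reads that section's position id from a separate slice of pos[]. A standalone C++ sketch of the lookup (mirroring the kernel logic; names are illustrative):

#include <cstdint>

static float theta_base_for(const int32_t * pos, int ne02, int i2, int ic,
                            int sect_0, int sect_1, int sect_2, int sect_3) {
    const int sect_dims = sect_0 + sect_1 + sect_2 + sect_3;
    const int sector    = ic % sect_dims;
    if (sector < sect_0)                   return (float) pos[i2];          // section 0
    if (sector < sect_0 + sect_1)          return (float) pos[i2 + ne02];   // section 1
    if (sector < sect_0 + sect_1 + sect_2) return (float) pos[i2 + ne02*2]; // section 2
    return (float) pos[i2 + ne02*3];                                        // section 3
}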
typedef void (im2col_t)(
device const float * x,
device char * dst,
...
...
@@ -6381,127 +6527,219 @@ kernel void kernel_mul_mm(
}
}
-// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
-// TODO: this kernel needs to be reimplemented from scratch for better performance
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-void kernel_mul_mm_id_impl(
-        int32_t  ne00,
-        int32_t  ne02,
-        uint64_t nb01,
-        uint64_t nb02,
-        int32_t  ne11,
-        int32_t  ne12,
-        uint64_t nb10,
-        uint64_t nb11,
-        uint64_t nb12,
-        int32_t  ne0,
-        int32_t  ne1,
-        int64_t  ne0ne1,
template<typename T4>
kernel void kernel_mul_mm_id_map0(
constant ggml_metal_kargs_mul_mm_id_map0 & args,
device const char * src1,
device const char * src2,
device char * hsrc1,
device char * htpe,
device char * hids,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int ide = tgpig[0]; // expert id
int n_all = 0;
device int32_t * ids_i32 = (device int32_t *) (hids);
for (int i21 = 0; i21 < args.neh11; i21++) { // n_tokens
device const int32_t * src2_i32 = (device const int32_t *) (src2 + i21*args.nb21);
for (int i20 = 0; i20 < args.ne20; i20++) { // n_expert_used
if (src2_i32[i20] != ide) {
continue;
}
device const float4 * src1_f32x4 = (device const float4 *) ( src1 + i21*args.nb12 + (i20%args.ne11)*args.nb11);
device T4 * hsrc1_f32x4 = (device T4 *) (hsrc1 + (ide*args.neh11 + n_all)*args.nbh11);
for (int64_t i00 = tpitg.x; i00 < args.ne10/4; i00 += ntg.x) {
hsrc1_f32x4[i00] = (T4) (src1_f32x4[i00]);
}
if (tpitg.x == 0) {
ids_i32[i21*args.ne20 + i20] = ide*args.neh11 + n_all;
}
++n_all;
}
}
if (tpitg.x == 0) {
device int32_t * tpe_i32 = (device int32_t *) (htpe);
tpe_i32[ide] = n_all;
}
}
typedef decltype(kernel_mul_mm_id_map0<half4>) kernel_mul_mm_id_map0_t;
template [[host_name("kernel_mul_mm_id_map0_f16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<half4>;
template<typename T>
kernel void kernel_mul_mm_id_map1(
constant ggml_metal_kargs_mul_mm_id_map1 & args,
device const char * hdst,
device const char * hids,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const int i20 = tgpig[0]; // used expert
const int i21 = tgpig[1]; // token
device const int32_t * ids_i32 = (device const int32_t *) (hids);
device float4 * dst_f32x4 = (device float4 *) (dst + i20*args.nb1 + i21*args.nb2);
const int id = ids_i32[i21*args.ne20 + i20];
const int ide = id / args.neh1;
const int idt = id % args.neh1;
device const float4 * hdst_f32x4 = (device const float4 *) (hdst + idt*args.nbh1 + ide*args.nbh2);
for (int64_t i0 = tpitg.x; i0 < args.neh0/4; i0 += ntg.x) {
dst_f32x4[i0] = hdst_f32x4[i0];
}
}
typedef decltype(kernel_mul_mm_id_map1<float>) kernel_mul_mm_id_map1_t;
template [[host_name("kernel_mul_mm_id_map1_f32")]] kernel kernel_mul_mm_id_map1_t kernel_mul_mm_id_map1<float>;
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0,
device const char * src1,
-        threadgroup ushort2 * rowids,
+        device const char * tpe,
         device       char * dst,
-        threadgroup  char * shmem,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiitg[[thread_index_in_threadgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {

-    threadgroup half  * sa = (threadgroup half  *)(shmem);
-    threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+    threadgroup T    * sa = (threadgroup T    *)(shmem);
+    threadgroup half * sb = (threadgroup half *)(shmem + 4096);

     const int r0 = tgpig.y;
     const int r1 = tgpig.x;
     const int im = tgpig.z;

+    device const int32_t * tpe_i32 = (device const int32_t *) (tpe);
+
-    if (r1*BLOCK_SIZE_N >= ne1) return;
+    const int neh1 = tpe_i32[im];
+
+    if (r1*BLOCK_SIZE_N >= neh1) {
+        return;
+    }

     // if this block is of 64x32 shape or smaller
-    short n_rows = (ne0       - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0       - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    short n_cols = (ne1       - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1       - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    const short n_rows = (args.neh0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.neh0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    const short n_cols = (     neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (     neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;

     // a thread shouldn't load data outside of the matrix
-    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
-    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+    const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

-    simdgroup_half8x8  ma[4];
-    simdgroup_float8x8 mb[2];
+    simdgroup_T8x8    ma[4];
+    simdgroup_half8x8 mb[2];
     simdgroup_float8x8 mc[8];

-    for (int i = 0; i < 8; i++){
+    for (short i = 0; i < 8; i++){
         mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
     }

     short il = (tiitg % THREAD_PER_ROW);

-    ushort offset1 = il/nl;
+    const int i12 = im%args.neh12;
+    const int i13 = im/args.neh12;
+
+    const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const short    offset1 = il/nl;

-    threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
-
-    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
-    device const float   * y = (device const float   *)(src1
-        + nb12 * id[1]
-        + nb11 * (id[0] % ne11)
-        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+    device const block_q * x = (device const block_q *)(src0
+        + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
+    device const half * y = (device const half *)(src1
+        + args.nbh13*i13
+        + args.nbh12*i12
+        + args.nbh11*(r1*BLOCK_SIZE_N + thread_col)
+        + args.nbh10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));

-    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
         // load data and store to threadgroup memory
-        half4x4 temp_a;
+        T4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);

-        for (int i = 0; i < 16; i++) {
-            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
-            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+        #pragma unroll(16)
+        for (short i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+                    + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+                    + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
         }

-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+        *(threadgroup half2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device half2x4 *) y);

         il = (il + 2 < nl) ? il + 2 : il % 2;
-        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
         y += BLOCK_SIZE_K;

         threadgroup_barrier(mem_flags::mem_threadgroup);

         // load matrices from threadgroup memory and conduct outer products
-        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
-        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+        threadgroup const T    * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
+        threadgroup const half * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));

+        #pragma unroll(BLOCK_SIZE_K/8)
-        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+        for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
             #pragma unroll(4)
-            for (int i = 0; i < 4; i++) {
+            for (short i = 0; i < 4; i++) {
                 simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
             }

             simdgroup_barrier(mem_flags::mem_none);

             #pragma unroll(2)
-            for (int i = 0; i < 2; i++) {
+            for (short i = 0; i < 2; i++) {
                 simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
             }

-            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
-            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+            lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
+            lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;

             #pragma unroll(8)
-            for (int i = 0; i < 8; i++){
+            for (short i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
             }
         }
     }

-    {
+    if ((r0 + 1) * BLOCK_SIZE_M <= args.neh0 && (r1 + 1) * BLOCK_SIZE_N <= neh1) {
+        device float * C = (device float *) dst +
+            (BLOCK_SIZE_M * r0 + 32*(sgitg &  1)) + \
+            (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.neh0 + im*args.neh1*args.neh0;
+
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.neh0 * (i/4), args.neh0);
+        }
+    } else {
+        // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
         threadgroup float * temp_str = ((threadgroup float *) shmem) \
                                      + 32*(sgitg&1) + (16*(sgitg>>1))*BLOCK_SIZE_M;
-        for (int i = 0; i < 8; i++) {
+        for (short i = 0; i < 8; i++) {
             simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
         }

         threadgroup_barrier(mem_flags::mem_threadgroup);

         if (sgitg == 0) {
             for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
-                int64_t joff =  jid[0]*ne0 + jid[1]*ne0ne1;
-                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + joff;
+                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.neh0 + im*args.neh1*args.neh0;
                 device float4 * D4 = (device float4 *) D;

                 threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
...
...
@@ -6521,66 +6759,6 @@ void kernel_mul_mm_id_impl(
}
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
constant ggml_metal_kargs_mul_mm_id & args,
device const char * src0s,
device const char * src1,
device char * dst,
device const char * ids,
threadgroup char * shmem [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiitg[[thread_index_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const int32_t i02 = tgpig.z;
tgpig.z = 0;
device const char * src0 = src0s + i02*args.nb02;
// row indices
threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
// TODO: parallelize this loop
int32_t _ne1 = 0;
for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
if (id == i02) {
if (tiitg == 0) {
rowids[_ne1] = ushort2(ii0, ii1);
}
_ne1++;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
args.ne00,
args.ne02,
args.nb01,
args.nb02,
args.ne11,
args.ne12,
args.nb10,
args.nb11,
args.nb12,
args.ne0,
_ne1,
(int64_t)args.ne0*args.ne1,
src0,
src1,
rowids,
dst,
shmem,
tgpig,
tiitg,
sgitg);
}
#define QK_NL 16
//
...
...
@@ -6621,63 +6799,64 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
// matrix-matrix multiplication
//
-typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_t;

-template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1, dequantize_f16>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mm_bf16_f32")]]    kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
+template [[host_name("kernel_mul_mm_bf16_f32")]]    kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
 #endif
-template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
-template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
-template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
-template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
-template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
-template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
-template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
-template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
-template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
-template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
-template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
-template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
-template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
-template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
-template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
-template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
+template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
+template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
+template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
+template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
+template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
+template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
//
// indirect matrix-matrix multiplication
//
-typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
+typedef decltype(kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mul_mm_id;

-template [[host_name("kernel_mul_mm_id_f32_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_mul_mm_id_f16_f32")]]  kernel mat_mm_id_t kernel_mul_mm_id<half4x4,  1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_id_f32_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_id_f16_f16")]]  kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half4x4,  1, dequantize_f16>;
 #if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4, 1, dequantize_bf16>;
+template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16>;
 #endif
template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, QK_NL, dequantize_iq4_xs>;
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
template [[host_name("kernel_mul_mm_id_iq2_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
template [[host_name("kernel_mul_mm_id_iq2_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
template [[host_name("kernel_mul_mm_id_iq3_xxs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
template [[host_name("kernel_mul_mm_id_iq3_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
template [[host_name("kernel_mul_mm_id_iq2_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
template [[host_name("kernel_mul_mm_id_iq1_s_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2, dequantize_iq4_nl>;
template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
//
// matrix-vector multiplication
...
...
ml/backend/ggml/ggml/src/ggml-opt.cpp
View file @ 0cefd46f
...
...
@@ -32,12 +32,15 @@ struct ggml_opt_context {
     ggml_cgraph           * allocated_graph      = nullptr;
     ggml_cgraph           * allocated_graph_copy = nullptr;
     struct ggml_context   * ctx_static           = nullptr;
-    struct ggml_context   * ctx_static_cpu       = nullptr;
+    struct ggml_context   * ctx_cpu              = nullptr;
     struct ggml_context   * ctx_compute          = nullptr;
     struct ggml_context   * ctx_copy             = nullptr;
     ggml_backend_buffer_t   buf_static           = nullptr;
-    ggml_backend_buffer_t   buf_static_cpu       = nullptr;
+    ggml_backend_buffer_t   buf_cpu              = nullptr;
     std::mt19937 rng;
+    enum ggml_opt_loss_type  loss_type;
+    enum ggml_opt_build_type build_type;
+    enum ggml_opt_build_type build_type_alloc;

     struct ggml_tensor * inputs  = nullptr;
     struct ggml_tensor * outputs = nullptr;
...
...
@@ -50,6 +53,11 @@ struct ggml_opt_context {
     struct ggml_cgraph * gf      = nullptr;
     struct ggml_cgraph * gb_grad = nullptr;
     struct ggml_cgraph * gb_opt  = nullptr;
+    bool static_graphs = false;
+    bool eval_ready    = false;
+    std::vector<struct ggml_tensor *> grad_accs;
+    std::vector<struct ggml_tensor *> grad_m;
+    std::vector<struct ggml_tensor *> grad_v;

     int64_t iter       = 1;
     int32_t opt_period = 1;
...
...
@@ -73,7 +81,13 @@ struct ggml_opt_result {

 // ====== Dataset ======

-ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) {
+ggml_opt_dataset_t ggml_opt_dataset_init(
+        enum ggml_type type_data,
+        enum ggml_type type_label,
+        int64_t        ne_datapoint,
+        int64_t        ne_label,
+        int64_t        ndata,
+        int64_t        ndata_shard) {
     GGML_ASSERT(ne_datapoint >  0);
     GGML_ASSERT(ne_label     >= 0);
     GGML_ASSERT(ndata        >  0);
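A hedged usage sketch of the extended initializer; passing GGML_TYPE_F32 for both types reproduces the previous hard-coded behavior (the sizes below are illustrative only):

ggml_opt_dataset_t dataset = ggml_opt_dataset_init(
    GGML_TYPE_F32, GGML_TYPE_F32,
    /*ne_datapoint =*/ 784,   /*ne_label    =*/ 10,
    /*ndata        =*/ 60000, /*ndata_shard =*/ 1);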
...
...
@@ -92,11 +106,11 @@ ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label,
         result->ctx = ggml_init(params);
     }

-    result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata);
+    result->data = ggml_new_tensor_2d(result->ctx, type_data, ne_datapoint, ndata);
     result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata;

     if (ne_label > 0) {
-        result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata);
+        result->labels = ggml_new_tensor_2d(result->ctx, type_label, ne_label, ndata);
         result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata;
     } else {
         result->labels = nullptr;
...
...
@@ -119,6 +133,10 @@ void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) {
     delete dataset;
 }

+int64_t ggml_opt_dataset_ndata(ggml_opt_dataset_t dataset) {
+    return dataset->ndata;
+}
+
 struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) {
     return dataset->data;
 }
...
...
@@ -144,6 +162,8 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
     GGML_ASSERT(   data_batch && ggml_is_contiguous(data_batch));
     GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch));
     GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
+    GGML_ASSERT(                   data_batch->type == dataset->data->type);
+    GGML_ASSERT(!labels_batch || labels_batch->type == dataset->labels->type);

     const size_t nb_data_batch = ggml_nbytes(data_batch);
     GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
...
...
@@ -171,6 +191,31 @@ void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor *
     }
 }

+void ggml_opt_dataset_get_batch_host(ggml_opt_dataset_t dataset, void * data_batch, size_t nb_data_batch, void * labels_batch, int64_t ibatch) {
+    GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr));
+    GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0);
+
+    const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data;
+
+    GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size()));
+
+    for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) {
+        const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch];
+
+        const char * ptr_data       = (const char *) dataset->data->data + ishard      *dataset->nbs_data;
+        char       * ptr_data_batch = (char       *) data_batch          + ishard_batch*dataset->nbs_data;
+        memcpy(ptr_data_batch, ptr_data, dataset->nbs_data);
+
+        if (!labels_batch) {
+            continue;
+        }
+
+        const char * ptr_labels       = (const char *) dataset->labels->data + ishard      *dataset->nbs_labels;
+        char       * ptr_labels_batch = (char       *) labels_batch          + ishard_batch*dataset->nbs_labels;
+        memcpy(ptr_labels_batch, ptr_labels, dataset->nbs_labels);
+    }
+}
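A hedged usage sketch for the host-side batch copy above; only the data-batch size is passed in, so the caller is responsible for sizing the label buffer to match (shards_per_batch * nbs_labels; both names here are illustrative):

std::vector<char> data(nb_data_batch);     // nb_data_batch: a multiple of the dataset's data shard size
std::vector<char> labels(nb_labels_batch); // sized by the caller; assumed = shards_per_batch*nbs_labels
ggml_opt_dataset_get_batch_host(dataset, data.data(), data.size(), labels.data(), /*ibatch =*/ 0);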
// ====== Model / Context ======

struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) {
...
...
@@ -187,17 +232,18 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     return result;
 }

+struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
+    return *((struct ggml_opt_optimizer_params *) userdata);
+}
+
 struct ggml_opt_params ggml_opt_default_params(
         ggml_backend_sched_t      backend_sched,
-        struct ggml_context     * ctx_compute,
-        struct ggml_tensor      * inputs,
-        struct ggml_tensor      * outputs,
         enum ggml_opt_loss_type   loss_type) {
     return {
         /*backend_sched   =*/ backend_sched,
-        /*ctx_compute     =*/ ctx_compute,
-        /*inputs          =*/ inputs,
-        /*logits          =*/ outputs,
+        /*ctx_compute     =*/ nullptr,
+        /*inputs          =*/ nullptr,
+        /*logits          =*/ nullptr,
         /*loss_type       =*/ loss_type,
         /*build_type      =*/ GGML_OPT_BUILD_TYPE_OPT,
         /*opt_period      =*/ 1,
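The net effect of this hunk: ggml_opt_default_params no longer pins the compute graph at construction time. A hedged before/after sketch (the later setter path is only referenced in this diff via the assert message in ggml_opt_build below, e.g. ggml_opt_prepare_alloc):

// before: graph context and tensors fixed when the params are built
ggml_opt_params p = ggml_opt_default_params(sched, ctx_compute, inputs, outputs, GGML_OPT_LOSS_TYPE_MEAN);
// after: only the scheduler and loss type are fixed up front;
// ctx_compute/inputs/outputs start out as nullptr and are supplied later
ggml_opt_params p = ggml_opt_default_params(sched, GGML_OPT_LOSS_TYPE_MEAN);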
...
...
@@ -266,195 +312,246 @@ static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * src) {
return
dst
;
}
static
void
ggml_opt_alloc_graph
(
ggml_opt_context_t
opt_ctx
,
ggml_cgraph
*
graph
)
{
GGML_ASSERT
(
graph
);
if
(
opt_ctx
->
allocated_graph
==
graph
)
{
return
;
}
ggml_backend_sched_reset
(
opt_ctx
->
backend_sched
);
// clear allocation of previous graph
{
ggml_init_params
params
=
{
/*.mem_size =*/
ggml_tensor_overhead
()
*
GGML_DEFAULT_GRAPH_SIZE
,
/*.mem_buffer =*/
nullptr
,
/*.no_alloc =*/
true
,
};
ggml_free
(
opt_ctx
->
ctx_copy
);
opt_ctx
->
ctx_copy
=
ggml_init
(
params
);
}
opt_ctx
->
allocated_graph_copy
=
dup_graph
(
opt_ctx
->
ctx_copy
,
graph
);
ggml_backend_sched_alloc_graph
(
opt_ctx
->
backend_sched
,
opt_ctx
->
allocated_graph_copy
);
opt_ctx
->
allocated_graph
=
graph
;
}
ggml_opt_context_t
ggml_opt_init
(
struct
ggml_opt_params
params
)
{
ggml_opt_context_t
result
=
new
struct
ggml_opt_context
;
result
->
backend_sched
=
params
.
backend_sched
;
result
->
ctx_compute
=
params
.
ctx_compute
;
result
->
inputs
=
params
.
inputs
;
result
->
outputs
=
params
.
outputs
;
result
->
opt_period
=
params
.
opt_period
;
result
->
get_opt_pars
=
params
.
get_opt_pars
;
result
->
get_opt_pars_ud
=
params
.
get_opt_pars_ud
;
GGML_ASSERT
(
result
->
inputs
->
data
&&
"the inputs must be allocated statically"
);
GGML_ASSERT
(
result
->
opt_period
>=
1
);
const
bool
accumulate
=
params
.
build_type
==
GGML_OPT_BUILD_TYPE_GRAD
||
(
params
.
build_type
==
GGML_OPT_BUILD_TYPE_OPT
&&
result
->
opt_period
>
1
);
static
void
ggml_opt_build
(
ggml_opt_context_t
opt_ctx
)
{
GGML_ASSERT
(
opt_ctx
->
ctx_compute
&&
"no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc"
);
GGML_ASSERT
((
!
opt_ctx
->
static_graphs
||
opt_ctx
->
inputs
->
data
)
&&
"when using static graphs the inputs must be allocated statically"
);
ggml_set_input
(
result
->
inputs
);
ggml_set_output
(
result
->
outputs
);
const
bool
accumulate
=
opt_ctx
->
build_type_alloc
>=
GGML_OPT_BUILD_TYPE_GRAD
&&
!
(
opt_ctx
->
static_graphs
&&
opt_ctx
->
build_type_alloc
==
GGML_OPT_BUILD_TYPE_OPT
&&
opt_ctx
->
opt_period
==
1
);
result
->
gf
=
ggml_new_graph_custom
(
result
->
ctx_compute
,
GGML_DEFAULT_GRAPH_SIZE
,
/*grads =*/
true
);
// Forward pass.
ggml_
build_forward_expand
(
result
->
gf
,
result
->
outputs
);
ggml_set_input
(
opt_ctx
->
inputs
);
ggml_
set_output
(
opt_ctx
->
outputs
);
int
n_param
=
0
;
for
(
int
i
=
0
;
i
<
result
->
gf
->
n_nodes
;
++
i
)
{
if
(
result
->
gf
->
nodes
[
i
]
->
flags
&
GGML_TENSOR_FLAG_PARAM
)
{
for
(
int
i
=
0
;
i
<
opt_ctx
->
gf
->
n_nodes
;
++
i
)
{
const
struct
ggml_tensor
*
node
=
opt_ctx
->
gf
->
nodes
[
i
];
if
(
node
->
flags
&
GGML_TENSOR_FLAG_PARAM
)
{
n_param
++
;
}
GGML_ASSERT
(
!
(
node
->
flags
&
GGML_TENSOR_FLAG_LOSS
)
&&
"support for extra loss terms not implemented"
);
}
{
if
(
!
opt_ctx
->
ctx_static
)
{
// The static context is used for:
// - gradients (1 tensor per param if using gradient accumulation)
// - gradients (1
per loss, 1
tensor per param if using gradient accumulation)
// - optimizer momenta (2 tensors per param)
// - labels
// - loss + its gradient (up to 5 tensors)
// - pred
// - ncorrect (2 tensors).
const
size_t
tensors_per_param
=
(
accumulate
?
1
:
0
)
+
(
params
.
build_type
==
GGML_OPT_BUILD_TYPE_OPT
?
2
:
0
);
const
size_t
size_meta
=
(
tensors_per_param
*
n_param
+
9
)
*
ggml_tensor_overhead
();
// - labels (if using static graphs)
// - loss (if using static graphs, up to 5 tensors)
// - pred (if using static graphs)
// - ncorrect (if using static graphs, 2 tensors).
constexpr
size_t
n_loss
=
1
;
const
size_t
tensors_per_param
=
(
accumulate
?
1
:
0
)
+
(
opt_ctx
->
build_type_alloc
==
GGML_OPT_BUILD_TYPE_OPT
?
2
:
0
);
const
size_t
tensors_const
=
opt_ctx
->
static_graphs
?
9
:
0
;
const
size_t
size_meta
=
(
n_loss
+
tensors_per_param
*
n_param
+
tensors_const
)
*
ggml_tensor_overhead
();
struct
ggml_init_params
params
=
{
/*.mem_size =*/
size_meta
,
/*.mem_buffer =*/
nullptr
,
/*.no_alloc =*/
true
,
};
result
->
ctx_static
=
ggml_init
(
params
);
opt_ctx
->
ctx_static
=
ggml_init
(
params
);
}
GGML_ASSERT
(
opt_ctx
->
build_type
<=
opt_ctx
->
build_type_alloc
);
{
// The static cpu context is used for:
// - optimizer parameters (1 for the entire context)
// The cpu context is allocated statically if using static graphs, dynamically otherwise.
// It is used for:
// - optimizer parameters (1 shared for all optimizer invocations)
const
size_t
size_meta
=
1
*
ggml_tensor_overhead
();
struct
ggml_init_params
params
=
{
/*.mem_size =*/
size_meta
,
/*.mem_buffer =*/
nullptr
,
/*.no_alloc =*/
true
,
};
result
->
ctx_static_cpu
=
ggml_init
(
params
);
ggml_free
(
opt_ctx
->
ctx_cpu
);
opt_ctx
->
ctx_cpu
=
ggml_init
(
params
);
ggml_backend_buffer_free
(
opt_ctx
->
buf_cpu
);
opt_ctx
->
buf_cpu
=
nullptr
;
}
struct
ggml_context
*
ctx_results
=
opt_ctx
->
static_graphs
?
opt_ctx
->
ctx_static
:
opt_ctx
->
ctx_compute
;
switch
(
params
.
loss_type
)
{
switch
(
opt_ctx
->
loss_type
)
{
case
GGML_OPT_LOSS_TYPE_MEAN
:
{
result
->
loss
=
ggml_sum
(
result
->
ctx_static
,
result
->
outputs
);
ggml_set_name
(
result
->
loss
,
"loss_sum"
);
const
float
scale
=
1.0
f
/
(
result
->
opt_period
*
ggml_nelements
(
result
->
outputs
));
result
->
loss
=
ggml_scale
(
result
->
ctx_static
,
result
->
loss
,
scale
);
ggml_set_name
(
result
->
loss
,
"loss_mean"
);
result
->
loss_per_datapoint
=
true
;
opt_ctx
->
loss
=
ggml_sum
(
ctx_
result
s
,
opt_ctx
->
outputs
);
ggml_set_name
(
opt_ctx
->
loss
,
"loss_sum"
);
const
float
scale
=
1.0
f
/
(
opt_ctx
->
opt_period
*
ggml_nelements
(
opt_ctx
->
outputs
));
opt_ctx
->
loss
=
ggml_scale
(
ctx_
result
s
,
opt_ctx
->
loss
,
scale
);
ggml_set_name
(
opt_ctx
->
loss
,
"loss_mean"
);
opt_ctx
->
loss_per_datapoint
=
true
;
break
;
}
case
GGML_OPT_LOSS_TYPE_SUM
:
{
result
->
loss
=
ggml_sum
(
result
->
ctx_static
,
result
->
outputs
);
ggml_set_name
(
result
->
loss
,
"loss_sum"
);
result
->
loss_per_datapoint
=
false
;
opt_ctx
->
loss
=
ggml_sum
(
ctx_
result
s
,
opt_ctx
->
outputs
);
ggml_set_name
(
opt_ctx
->
loss
,
"loss_sum"
);
opt_ctx
->
loss_per_datapoint
=
false
;
break
;
}
case
GGML_OPT_LOSS_TYPE_CROSS_ENTROPY
:
{
result
->
labels
=
ggml_dup_tensor
(
result
->
ctx_static
,
result
->
outputs
);
ggml_set_input
(
result
->
labels
);
ggml_set_name
(
result
->
labels
,
"labels"
);
result
->
loss
=
ggml_cross_entropy_loss
(
result
->
ctx_static
,
result
->
outputs
,
result
->
labels
);
ggml_set_name
(
result
->
loss
,
"loss_cross_entropy"
);
if
(
result
->
opt_period
>
1
)
{
result
->
loss
=
ggml_scale
(
result
->
ctx_static
,
result
->
loss
,
1.0
f
/
result
->
opt_period
);
ggml_set_name
(
result
->
loss
,
"loss_cross_entropy_scaled"
);
}
result
->
loss_per_datapoint
=
true
;
opt_ctx
->
labels
=
ggml_dup_tensor
(
ctx_
result
s
,
opt_ctx
->
outputs
);
ggml_set_input
(
opt_ctx
->
labels
);
ggml_set_name
(
opt_ctx
->
labels
,
"labels"
);
opt_ctx
->
loss
=
ggml_cross_entropy_loss
(
ctx_results
,
opt_ctx
->
outputs
,
opt_ctx
->
labels
);
ggml_set_name
(
opt_ctx
->
loss
,
"loss_cross_entropy"
);
if
(
opt_ctx
->
opt_period
>
1
)
{
opt_ctx
->
loss
=
ggml_scale
(
ctx_results
,
opt_ctx
->
loss
,
1.0
f
/
opt_ctx
->
opt_period
);
ggml_set_name
(
opt_ctx
->
loss
,
"loss_cross_entropy_scaled"
);
}
opt_ctx
->
loss_per_datapoint
=
true
;
break
;
}
case
GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR
:
{
result
->
labels
=
ggml_dup_tensor
(
result
->
ctx_static
,
result
->
outputs
);
ggml_set_input
(
result
->
labels
);
ggml_set_name
(
result
->
labels
,
"labels"
);
result
->
loss
=
ggml_sub
(
result
->
ctx_static
,
result
->
outputs
,
result
->
labels
);
ggml_set_name
(
result
->
loss
,
"loss_error"
);
result
->
loss
=
ggml_sqr
(
result
->
ctx_static
,
result
->
loss
);
ggml_set_name
(
result
->
loss
,
"loss_squared_error"
);
result
->
loss
=
ggml_sum
(
result
->
ctx_static
,
result
->
loss
);
ggml_set_name
(
result
->
loss
,
"loss_sum_squared_error"
);
const
float
scale
=
1.0
f
/
(
result
->
opt_period
*
ggml_nelements
(
result
->
outputs
));
result
->
loss
=
ggml_scale
(
result
->
ctx_static
,
result
->
loss
,
scale
);
ggml_set_name
(
result
->
loss
,
"loss_mean_squared_error"
);
result
->
loss_per_datapoint
=
true
;
opt_ctx
->
labels
=
ggml_dup_tensor
(
ctx_
result
s
,
opt_ctx
->
outputs
);
ggml_set_input
(
opt_ctx
->
labels
);
ggml_set_name
(
opt_ctx
->
labels
,
"labels"
);
opt_ctx
->
loss
=
ggml_sub
(
ctx_results
,
opt_ctx
->
outputs
,
opt_ctx
->
labels
);
ggml_set_name
(
opt_ctx
->
loss
,
"loss_error"
);
opt_ctx
->
loss
=
ggml_sqr
(
ctx_
result
s
,
opt_ctx
->
loss
);
ggml_set_name
(
opt_ctx
->
loss
,
"loss_squared_error"
);
opt_ctx
->
loss
=
ggml_sum
(
ctx_
result
s
,
opt_ctx
->
loss
);
ggml_set_name
(
opt_ctx
->
loss
,
"loss_sum_squared_error"
);
const
float
scale
=
1.0
f
/
(
opt_ctx
->
opt_period
*
ggml_nelements
(
opt_ctx
->
outputs
));
opt_ctx
->
loss
=
ggml_scale
(
ctx_
result
s
,
opt_ctx
->
loss
,
scale
);
ggml_set_name
(
opt_ctx
->
loss
,
"loss_mean_squared_error"
);
opt_ctx
->
loss_per_datapoint
=
true
;
break
;
}
}
ggml_set_output
(
result
->
loss
);
ggml_set_loss
(
result
->
loss
);
ggml_build_forward_expand
(
result
->
gf
,
result
->
loss
);
ggml_set_output
(
opt_ctx
->
loss
);
ggml_set_loss
(
opt_ctx
->
loss
);
ggml_build_forward_expand
(
opt_ctx
->
gf
,
opt_ctx
->
loss
);
if
(
opt_ctx
->
loss_type
==
GGML_OPT_LOSS_TYPE_CROSS_ENTROPY
)
{
opt_ctx
->
pred
=
ggml_argmax
(
ctx_results
,
opt_ctx
->
outputs
);
ggml_set_name
(
opt_ctx
->
pred
,
"pred"
);
ggml_set_output
(
opt_ctx
->
pred
);
ggml_build_forward_expand
(
opt_ctx
->
gf
,
opt_ctx
->
pred
);
opt_ctx
->
ncorrect
=
ggml_count_equal
(
ctx_results
,
opt_ctx
->
pred
,
ggml_argmax
(
ctx_results
,
opt_ctx
->
labels
));
ggml_set_name
(
opt_ctx
->
ncorrect
,
"ncorrect"
);
ggml_set_output
(
opt_ctx
->
ncorrect
);
ggml_build_forward_expand
(
opt_ctx
->
gf
,
opt_ctx
->
ncorrect
);
}
if
(
opt_ctx
->
buf_static
)
{
if
(
opt_ctx
->
build_type
==
GGML_OPT_BUILD_TYPE_FORWARD
)
{
return
;
}
}
else
if
(
opt_ctx
->
build_type_alloc
==
GGML_OPT_BUILD_TYPE_FORWARD
)
{
opt_ctx
->
buf_static
=
ggml_backend_alloc_ctx_tensors
(
opt_ctx
->
ctx_static
,
ggml_backend_sched_get_backend
(
opt_ctx
->
backend_sched
,
0
));
return
;
}
result
->
pred
=
ggml_argmax
(
result
->
ctx_static
,
result
->
outputs
);
ggml_set_name
(
result
->
pred
,
"pred"
);
ggml_set_output
(
result
->
pred
);
ggml_build_forward_expand
(
result
->
gf
,
result
->
pred
);
if
(
opt_ctx
->
grad_accs
.
empty
())
{
GGML_ASSERT
(
opt_ctx
->
build_type_alloc
>=
GGML_OPT_BUILD_TYPE_GRAD
);
if
(
result
->
labels
)
{
result
->
ncorrect
=
ggml_count_equal
(
result
->
ctx_static
,
result
->
pred
,
ggml_argmax
(
result
->
ctx_static
,
result
->
labels
));
ggml_set_name
(
result
->
ncorrect
,
"ncorrect"
);
ggml_set_output
(
result
->
ncorrect
);
ggml_build_forward_expand
(
result
->
gf
,
result
->
ncorrect
);
const
int
n_nodes
=
opt_ctx
->
gf
->
n_nodes
;
opt_ctx
->
grad_accs
.
resize
(
n_nodes
);
for
(
int
i
=
0
;
i
<
n_nodes
;
++
i
)
{
ggml_tensor
*
node
=
opt_ctx
->
gf
->
nodes
[
i
];
if
((
accumulate
&&
(
node
->
flags
&
GGML_TENSOR_FLAG_PARAM
))
||
(
node
->
flags
&
GGML_TENSOR_FLAG_LOSS
))
{
opt_ctx
->
grad_accs
[
i
]
=
ggml_new_tensor
(
opt_ctx
->
ctx_static
,
GGML_TYPE_F32
,
GGML_MAX_DIMS
,
node
->
ne
);
}
else
{
result
->
ncorrect
=
nullptr
;
opt_ctx
->
grad_accs
[
i
]
=
nullptr
;
}
}
if
(
params
.
build_type
==
GGML_OPT_BUILD_TYPE_FORWARD
)
{
result
->
buf_static
=
ggml_backend_alloc_ctx_tensors
(
result
->
ctx_static
,
ggml_backend_sched_get_backend
(
result
->
backend_sched
,
0
));
return
result
;
+        if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
+            opt_ctx->grad_m.resize(n_nodes);
+            opt_ctx->grad_v.resize(n_nodes);
+            for (int i = 0; i < n_nodes; ++i) {
+                ggml_tensor * node = opt_ctx->gf->nodes[i];
+                if (node->flags & GGML_TENSOR_FLAG_PARAM) {
+                    opt_ctx->grad_m[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+                    opt_ctx->grad_v[i] = ggml_new_tensor(opt_ctx->ctx_static, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+                } else {
+                    opt_ctx->grad_m[i] = nullptr;
+                    opt_ctx->grad_v[i] = nullptr;
+                }
+            }
+        }
+    }
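grad_m and grad_v are the AdamW first- and second-moment buffers, allocated as F32 tensors only for nodes flagged as parameters (every other node gets nullptr). That puts the optimizer-state overhead at two floats, i.e. 8 bytes, per trainable weight. A quick back-of-the-envelope sketch (the parameter count is made up):

#include <cstdint>
#include <cstdio>

// Back-of-the-envelope cost of the grad_m/grad_v buffers allocated above:
// two F32 tensors per parameter tensor, i.e. 8 extra bytes per weight.
int main() {
    const int64_t n_params = 124000000; // hypothetical parameter count
    const double bytes = 2.0 * sizeof(float) * (double) n_params;
    std::printf("AdamW moment buffers: %.1f MiB\n", bytes / (1024.0*1024.0));
    return 0;
}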
    // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
-    result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf);
-    ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate);
+    opt_ctx->gb_grad = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gf, /*force_grads =*/ true);
+    ggml_build_backward_expand(opt_ctx->ctx_compute, opt_ctx->gb_grad, opt_ctx->grad_accs.data());
-    if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) {
-        result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
-        ggml_graph_reset(result->gb_grad);
-        return result;
+    if (opt_ctx->buf_static) {
+        if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_GRAD) {
+            return;
+        }
+    } else if (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_GRAD) {
+        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
+        ggml_graph_reset(opt_ctx->gb_grad);
+    }
-    GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT);
+    GGML_ASSERT(opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT);
    // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
-    result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad);
+    opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);
-    result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7);
-    ggml_set_input(result->adamw_params);
-    ggml_set_name(result->adamw_params, "adamw_params");
+    opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
+    ggml_set_input(opt_ctx->adamw_params);
+    ggml_set_name(opt_ctx->adamw_params, "adamw_params");
-    for (int i = result->gf->n_nodes - 1; i >= 0; --i) {
-        struct ggml_tensor * node = result->gb_opt->nodes[i];
-        struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node);
+    for (int i = opt_ctx->gf->n_nodes - 1; i >= 0; --i) {
+        struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
+        struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);
-        if (node->flags & GGML_TENSOR_FLAG_PARAM) {
-            struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node);
-            struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node);
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params);
-            ggml_build_forward_expand(result->gb_opt, opt_step);
+        if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
+            struct ggml_tensor * m        = opt_ctx->grad_m[i];
+            struct ggml_tensor * v        = opt_ctx->grad_v[i];
+            struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
+            ggml_set_name(m,        (std::string("AdamW m for ")    + std::string(node->name)).c_str());
+            ggml_set_name(v,        (std::string("AdamW v for ")    + std::string(node->name)).c_str());
+            ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
+            ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
+        }
+    }
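ggml_opt_step_adamw consumes the parameter, its gradient, the m/v moments and the 7-float adamw_params tensor created above. For orientation, a textbook AdamW step with decoupled weight decay looks like the following; this is a sketch of the standard algorithm with the bias corrections folded into beta1h/beta2h, not a transcription of the ggml kernel:

#include <cmath>

// Textbook AdamW update for a single scalar weight; an elementwise kernel
// applies the same recurrence to every parameter. beta1h/beta2h are assumed
// to be the precomputed alpha/(1 - beta1^t) and 1/(1 - beta2^t) terms.
void adamw_step(float & w, float g, float & m, float & v,
                float alpha, float beta1, float beta2, float eps, float wd,
                float beta1h, float beta2h) {
    m = beta1*m + (1.0f - beta1)*g;
    v = beta2*v + (1.0f - beta2)*g*g;
    const float mh = m * beta1h;                 // bias-corrected first moment, scaled by alpha
    const float vh = std::sqrt(v * beta2h) + eps;
    w = w*(1.0f - alpha*wd) - mh/vh;             // decoupled weight decay, then Adam step
}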
-    result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0));
+    if (!opt_ctx->buf_static) {
+        opt_ctx->buf_static = ggml_backend_alloc_ctx_tensors(opt_ctx->ctx_static, ggml_backend_sched_get_backend(opt_ctx->backend_sched, 0));
+        ggml_graph_reset(opt_ctx->gb_opt);
+    }
+    opt_ctx->buf_cpu = ggml_backend_alloc_ctx_tensors_from_buft(opt_ctx->ctx_cpu, ggml_backend_cpu_buffer_type());
+}
ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
    ggml_opt_context_t result = new struct ggml_opt_context;
    result->backend_sched    = params.backend_sched;
    result->ctx_compute      = params.ctx_compute;
    result->loss_type        = params.loss_type;
    result->build_type       = params.build_type;
+    result->build_type_alloc = params.build_type;
    result->inputs           = params.inputs;
    result->outputs          = params.outputs;
    result->opt_period       = params.opt_period;
    result->get_opt_pars     = params.get_opt_pars;
    result->get_opt_pars_ud  = params.get_opt_pars_ud;
    GGML_ASSERT(result->opt_period >= 1);
+    result->static_graphs = result->ctx_compute;
+    if (!result->static_graphs) {
+        GGML_ASSERT(!result->inputs);
+        GGML_ASSERT(!result->outputs);
+        return result;
+    }
+    GGML_ASSERT(result->inputs);
+    GGML_ASSERT(result->outputs);
-    result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type());
    result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
    // Forward pass.
    ggml_build_forward_expand(result->gf, result->outputs);
-    ggml_graph_reset(result->gb_opt);
+    ggml_opt_build(result);
    return result;
}
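Note that ggml_opt_init now tolerates params.ctx_compute == nullptr: static_graphs stays false, the early return skips graph construction, and the caller instead supplies a fresh forward graph every step via ggml_opt_prepare_alloc (defined further down in this diff). A hedged usage sketch of that dynamic path, with the include path and all per-step objects assumed to be provided by the caller:

#include "ggml-opt.h"

// Sketch of the dynamic-graph path (params.ctx_compute == nullptr at init):
// the caller rebuilds the forward graph each step and hands it over before
// allocation. All arguments are assumed to be constructed by the caller.
void train_step_dynamic(ggml_opt_context_t opt_ctx,
                        ggml_opt_result_t  result,
                        struct ggml_context * ctx_compute,
                        struct ggml_cgraph  * gf,
                        struct ggml_tensor  * inputs,
                        struct ggml_tensor  * outputs) {
    ggml_opt_prepare_alloc(opt_ctx, ctx_compute, gf, inputs, outputs);
    ggml_opt_alloc(opt_ctx, /*backward =*/ true);
    // ... copy this step's data into `inputs` here ...
    ggml_opt_eval(opt_ctx, result);
}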
...
...
@@ -464,9 +561,9 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) {
        return;
    }
    ggml_backend_buffer_free(opt_ctx->buf_static);
-    ggml_backend_buffer_free(opt_ctx->buf_static_cpu);
+    ggml_backend_buffer_free(opt_ctx->buf_cpu);
    ggml_free(opt_ctx->ctx_static);
-    ggml_free(opt_ctx->ctx_static_cpu);
+    ggml_free(opt_ctx->ctx_cpu);
    delete opt_ctx;
}
...
...
@@ -582,8 +679,79 @@ void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, doubl
// ====== Computation ======
-static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) {
-    if (graph != opt_ctx->gf) {
+void ggml_opt_prepare_alloc(
+        ggml_opt_context_t    opt_ctx,
+        struct ggml_context * ctx_compute,
+        struct ggml_cgraph  * gf,
+        struct ggml_tensor  * inputs,
+        struct ggml_tensor  * outputs) {
+    GGML_ASSERT(!opt_ctx->static_graphs);
+    opt_ctx->ctx_compute = ctx_compute;
+    opt_ctx->gf          = gf;
+    opt_ctx->inputs      = inputs;
+    opt_ctx->outputs     = outputs;
+}
+void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
+    GGML_ASSERT(!opt_ctx->eval_ready);
+    if (opt_ctx->build_type == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period > 1 && opt_ctx->opt_i == 0) {
+        ggml_graph_reset(opt_ctx->gb_grad);
+    }
+    if (backward) {
+        const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
+        opt_ctx->build_type = opt_i_next == 0 ? GGML_OPT_BUILD_TYPE_OPT : GGML_OPT_BUILD_TYPE_GRAD;
+    } else {
+        opt_ctx->build_type = GGML_OPT_BUILD_TYPE_FORWARD;
+    }
+    if (!opt_ctx->static_graphs) {
+        ggml_opt_build(opt_ctx);
+    }
+    struct ggml_cgraph * graph = nullptr;
+    switch (opt_ctx->build_type) {
+        case GGML_OPT_BUILD_TYPE_FORWARD: {
+            graph = opt_ctx->gf;
+        } break;
+        case GGML_OPT_BUILD_TYPE_GRAD: {
+            graph = opt_ctx->gb_grad;
+        } break;
+        case GGML_OPT_BUILD_TYPE_OPT: {
+            graph = opt_ctx->gb_opt;
+        } break;
+    }
+    GGML_ASSERT(graph);
+    if (opt_ctx->allocated_graph == graph) {
+        opt_ctx->eval_ready = true;
+        return;
+    }
+    ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph
+    if (opt_ctx->static_graphs) {
+        ggml_init_params params = {
+            /*.mem_size   =*/ graph->size*ggml_tensor_overhead() + ggml_graph_overhead_custom(graph->size, graph->grads),
+            /*.mem_buffer =*/ nullptr,
+            /*.no_alloc   =*/ true,
+        };
+        ggml_free(opt_ctx->ctx_copy);
+        opt_ctx->ctx_copy = ggml_init(params);
+        opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph);
+    } else {
+        opt_ctx->allocated_graph_copy = graph;
+    }
+    ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
+    opt_ctx->allocated_graph = graph;
+    opt_ctx->eval_ready = true;
+}
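The backward flag plus opt_period fully determine which graph ggml_opt_alloc picks: with backward=true, every opt_period-th step gets the optimizer graph gb_opt and the steps in between get the accumulation-only graph gb_grad. A stand-alone demonstration of that schedule in plain C++:

#include <cstdio>

// Demo of the graph choice made by ggml_opt_alloc above for backward=true,
// mirroring the opt_i_next computation; no ggml calls involved.
int main() {
    const int opt_period = 4;
    int opt_i = 0;
    for (int step = 0; step < 8; ++step) {
        const int opt_i_next = (opt_i + 1) % opt_period;
        std::printf("step %d -> %s\n", step, opt_i_next == 0 ? "gb_opt" : "gb_grad");
        opt_i = opt_i_next; // mirrors the update done later in ggml_opt_eval
    }
    return 0;
}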
+void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
+    GGML_ASSERT(opt_ctx->eval_ready);
+    if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
+        struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+        GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
...
...
@@ -609,9 +777,19 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
        adamw_par_data[6] = beta2h;
    }
-    ggml_opt_alloc_graph(opt_ctx, graph);
    ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
    opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt;
+    opt_ctx->opt_i = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
+    if (!opt_ctx->static_graphs) {
+        opt_ctx->gf                   = nullptr;
+        opt_ctx->gb_grad              = nullptr;
+        opt_ctx->gb_opt               = nullptr;
+        opt_ctx->allocated_graph      = nullptr;
+        opt_ctx->allocated_graph_copy = nullptr;
+    }
+    opt_ctx->eval_ready = false;
    if (!result) {
        return;
...
...
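The adamw_par_data writes truncated above fill the 7-float adamw_params tensor; judging from the surrounding code, the last two slots hold the per-iteration bias corrections, beta1h = alpha/(1 - beta1^iter) and beta2h = 1/(1 - beta2^iter) (an assumption worth checking against the full file). A small sketch of how those terms evolve over the first iterations, using the common AdamW defaults:

#include <cmath>
#include <cstdio>

// Sketch of the assumed bias-correction terms stored in adamw_params[5]/[6].
int main() {
    const float alpha = 1e-3f, beta1 = 0.9f, beta2 = 0.999f;
    for (int iter = 1; iter <= 3; ++iter) {
        const float beta1h = alpha / (1.0f - std::pow(beta1, iter));
        const float beta2h = 1.0f  / (1.0f - std::pow(beta2, iter));
        std::printf("iter=%d beta1h=%g beta2h=%g\n", iter, beta1h, beta2h);
    }
    return 0;
}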
@@ -635,12 +813,14 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
    ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss));
    result->loss.push_back(loss);
+    if (opt_ctx->pred) {
        GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32);
        std::vector<int32_t> pred(ndata);
        ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred));
        result->pred.insert(result->pred.end(), pred.begin(), pred.end());
+    }
-    if (!opt_ctx->labels || result->ncorrect < 0) {
+    if (!opt_ctx->ncorrect || result->ncorrect < 0) {
        result->ncorrect = -1;
        return;
    }
...
...
@@ -652,26 +832,6 @@ static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph,
    result->ncorrect += ncorrect;
}
-void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
-    ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result);
-}
-void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) {
-    if (opt_ctx->opt_period == 1) {
-        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
-        return;
-    }
-    const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period;
-    if (opt_i_next == 0) {
-        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result);
-        ggml_opt_reset(opt_ctx, /*optimizer =*/ false);
-    } else {
-        ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result);
-    }
-    opt_ctx->opt_i = opt_i_next;
-}
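With ggml_opt_forward and ggml_opt_forward_backward deleted, callers pick the direction through ggml_opt_alloc's backward flag and then run ggml_opt_eval, as the updated epoch loop below shows. A minimal mapping sketch (the *_compat wrapper names are mine, not part of ggml; include path assumed):

#include "ggml-opt.h"

// Thin compatibility wrappers spelling out the old call -> new call mapping.
static void opt_forward_compat(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
    ggml_opt_alloc(opt_ctx, /*backward =*/ false); // was: ggml_opt_forward()
    ggml_opt_eval(opt_ctx, result);
}

static void opt_forward_backward_compat(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
    ggml_opt_alloc(opt_ctx, /*backward =*/ true);  // was: ggml_opt_forward_backward()
    ggml_opt_eval(opt_ctx, result);
}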
// ====== High-Level Functions ======
void ggml_opt_epoch(
...
...
@@ -700,16 +860,18 @@ void ggml_opt_epoch(
    int64_t ibatch = 0;
    int64_t t_loop_start = ggml_time_us();
    for (; ibatch < ibatch_split; ++ibatch) {
+        ggml_opt_alloc(opt_ctx, /*backward =*/ true);
        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
-        ggml_opt_forward_backward(opt_ctx, result_train);
+        ggml_opt_eval(opt_ctx, result_train);
        if (callback_train) {
            callback_train(true, opt_ctx, dataset, result_train, ibatch + 1, ibatch_split, t_loop_start);
        }
    }
    t_loop_start = ggml_time_us();
    for (; ibatch < nbatches; ++ibatch) {
+        ggml_opt_alloc(opt_ctx, /*backward =*/ false);
        ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch);
-        ggml_opt_forward(opt_ctx, result_eval);
+        ggml_opt_eval(opt_ctx, result_eval);
        if (callback_eval) {
            callback_eval(false, opt_ctx, dataset, result_eval, ibatch + 1 - ibatch_split, nbatches - ibatch_split, t_loop_start);
        }
...
...
@@ -726,13 +888,26 @@ void ggml_opt_epoch_callback_progress_bar(
        int64_t t_start_us) {
    fprintf(stderr, "%s[", train ? "train: " : "val: ");
-    constexpr int64_t bar_length = 25;
+    // The progress bar consists of partially filled blocks, unicode has 8 separate fill levels.
+    constexpr int64_t bar_length = 8;
+    const int64_t ibatch8 = 8*ibatch;
    for (int64_t j = 0; j < bar_length; ++j) {
-        const int64_t ibatch_j = ibatch_max * j/bar_length;
-        if (ibatch_j < ibatch) {
-            fprintf(stderr, "=");
-        } else if (ibatch_max * (j - 1)/bar_length < ibatch) {
-            fprintf(stderr, ">");
+        if        (ibatch_max*(8*j + 8)/bar_length < ibatch8) {
+            fprintf(stderr, "\u2588"); // full block
+        } else if (ibatch_max*(8*j + 7)/bar_length < ibatch8) {
+            fprintf(stderr, "\u2589"); // 7/8 filled
+        } else if (ibatch_max*(8*j + 6)/bar_length < ibatch8) {
+            fprintf(stderr, "\u258A"); // 6/8 filled
+        } else if (ibatch_max*(8*j + 5)/bar_length < ibatch8) {
+            fprintf(stderr, "\u258B"); // 5/8 filled
+        } else if (ibatch_max*(8*j + 4)/bar_length < ibatch8) {
+            fprintf(stderr, "\u258C"); // 4/8 filled
+        } else if (ibatch_max*(8*j + 3)/bar_length < ibatch8) {
+            fprintf(stderr, "\u258D"); // 3/8 filled
+        } else if (ibatch_max*(8*j + 2)/bar_length < ibatch8) {
+            fprintf(stderr, "\u258E"); // 2/8 filled
+        } else if (ibatch_max*(8*j + 1)/bar_length < ibatch8) {
+            fprintf(stderr, "\u258F"); // 1/8 filled
        } else {
            fprintf(stderr, " ");
        }
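The rewritten bar packs progress into bar_length = 8 cells with 8 sub-levels each, so the comparison ibatch_max*(8*j + k)/bar_length < 8*ibatch effectively asks whether cell j is filled to at least k eighths. A self-contained rendering of the same arithmetic, printing fill levels 0..8 instead of the unicode block characters:

#include <cstdint>
#include <cstdio>

// Same fill computation as the progress bar above, digits instead of blocks.
int main() {
    const int64_t bar_length = 8;
    const int64_t ibatch = 3, ibatch_max = 10; // example progress: 3/10 batches
    const int64_t ibatch8 = 8*ibatch;
    for (int64_t j = 0; j < bar_length; ++j) {
        int fill = 0;
        while (fill < 8 && ibatch_max*(8*j + fill + 1)/bar_length < ibatch8) {
            ++fill;
        }
        std::printf("%d", fill); // 8 = full cell, 0 = blank
    }
    std::printf("  (%lld/%lld batches)\n", (long long) ibatch, (long long) ibatch_max);
    return 0;
}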
...
...
@@ -764,8 +939,8 @@ void ggml_opt_epoch_callback_progress_bar(
    const int64_t t_eta_m = t_eta_s / 60;
    t_eta_s -= t_eta_m * 60;
-    fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, "
-            "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r",
+    fprintf(stderr, "] data=%07" PRId64 "/%07" PRId64 " loss=%.5lf±%.5lf acc=%.2lf±%.2lf%% "
+            "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 " \r",
            idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc,
            t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s);
    if (ibatch == ibatch_max) {
...
...
@@ -806,7 +981,10 @@ void ggml_opt_fit(
    int64_t epoch = 1;
-    ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type);
+    ggml_opt_params params = ggml_opt_default_params(backend_sched, loss_type);
+    params.ctx_compute = ctx_compute;
+    params.inputs      = inputs;
+    params.outputs     = outputs;
    params.opt_period      = opt_period;
    params.get_opt_pars    = get_opt_pars;
    params.get_opt_pars_ud = &epoch;
...
...
ml/backend/ggml/ggml/src/ggml-quants.c View file @ 0cefd46f
...
...
@@ -19,12 +19,6 @@
#define GROUP_MAX_EPS_IQ1_M 1e-7f
#define GROUP_MAX_EPS_IQ1_S 1e-12f
-#if defined(_MSC_VER)
-// disable "possible loss of data" to avoid warnings for hundreds of casts
-// we should just be careful :)
-#pragma warning(disable: 4244 4267)
-#endif
#define UNUSED GGML_UNUSED
// reference implementation for deterministic creation of model files
...
...
ml/backend/ggml/ggml/src/ggml.c View file @ 0cefd46f
...
...
@@ -1301,6 +1301,10 @@ bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
    return ggml_is_contiguous_n(tensor, 2);
}
+bool ggml_is_contiguously_allocated(const struct ggml_tensor * tensor) {
+    return ggml_nbytes(tensor) == ggml_nelements(tensor)*ggml_type_size(tensor->type)/ggml_blck_size(tensor->type);
+}
bool ggml_is_permuted(const struct ggml_tensor * tensor) {
    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
...
...
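The new predicate above is deliberately weaker than ggml_is_contiguous: it only checks that the tensor's byte span equals the packed size nelements*type_size/blck_size, so e.g. a transposed view (strides swapped, data untouched) still qualifies. A sketch of the distinction under the usual ggml context setup (context size arbitrary, error handling omitted):

#include <cstdio>
#include "ggml.h"

// A transposed view is contiguously allocated but not contiguous.
int main() {
    struct ggml_init_params params = { 16*1024*1024, NULL, false };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 8);
    struct ggml_tensor * at = ggml_transpose(ctx, a); // swaps ne/nb, data stays put

    std::printf("a:  contiguous=%d contiguously_allocated=%d\n",
                ggml_is_contiguous(a),  ggml_is_contiguously_allocated(a));  // expect 1 1
    std::printf("at: contiguous=%d contiguously_allocated=%d\n",
                ggml_is_contiguous(at), ggml_is_contiguously_allocated(at)); // expect 0 1
    ggml_free(ctx);
    return 0;
}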
@@ -2730,11 +2734,11 @@ void ggml_mul_mat_set_prec(
    c = ggml_mul_mat_id(ctx, as, b, ids);
    as  -> [cols, rows, n_expert]
-    ids -> [n_experts_used, n_tokens] (i32)
    b   -> [cols, n_expert_used, n_tokens]
+    ids -> [n_expert_used, n_tokens] (i32)
    c   -> [rows, n_expert_used, n_tokens]
-    in b, n_experts_used can be broadcasted to match the n_expert_used of ids
+    in b, n_expert_used can be broadcasted to match the n_expert_used of ids
    c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
*/
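A concrete instance of the corrected shape comment: with n_expert = 8 experts of shape [cols = 64, rows = 32], n_expert_used = 2 and n_tokens = 5, the operands line up as printed below. All numbers are illustrative only:

#include <cstdio>

// Worked example of the ggml_mul_mat_id shapes documented above.
int main() {
    const int cols = 64, rows = 32, n_expert = 8, n_expert_used = 2, n_tokens = 5;
    std::printf("as  -> [%d, %d, %d]\n", cols, rows, n_expert);
    std::printf("b   -> [%d, %d, %d]\n", cols, n_expert_used, n_tokens);
    std::printf("ids -> [%d, %d] (i32)\n", n_expert_used, n_tokens);
    std::printf("c   -> [%d, %d, %d]\n", rows, n_expert_used, n_tokens);
    return 0;
}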
...
...
@@ -5516,7 +5520,7 @@ static void ggml_compute_backward(
            // tensor = src0 * 1 + src1 * 0
            if (src0_needs_grads) {
                // dsrc0 = dtensor * 1
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad, src0));
            }
            if (src1_needs_grads) {
                // dsrc1 = dtensor * 0 -> noop
...
...
@@ -5797,10 +5801,9 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
}
void ggml_build_backward_expand(
-        struct ggml_context * ctx_static,
-        struct ggml_context * ctx_compute,
+        struct ggml_context * ctx,
        struct ggml_cgraph  * cgraph,
-        bool accumulate) {
+        struct ggml_tensor ** grad_accs) {
    GGML_ASSERT(cgraph->n_nodes > 0);
    GGML_ASSERT(cgraph->grads);
    GGML_ASSERT(cgraph->grad_accs);
...
...
@@ -5873,21 +5876,24 @@ void ggml_build_backward_expand(
        GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW ||
            node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE);
-        const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node);
-        GGML_ASSERT(igrad != GGML_HASHSET_FULL);
-        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, igrad));
-        if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) {
-            cgraph->grad_accs[igrad] = ggml_dup_tensor(ctx_static, node);
-            cgraph->grads[igrad]     = cgraph->grad_accs[igrad];
-            ggml_format_name(cgraph->grad_accs[igrad], "grad acc for %s", node->name);
+        const size_t ihash = ggml_hash_find(&cgraph->visited_hash_set, node);
+        GGML_ASSERT(ihash != GGML_HASHSET_FULL);
+        GGML_ASSERT(ggml_bitset_get(cgraph->visited_hash_set.used, ihash));
+        if (grad_accs && grad_accs[i]) {
+            cgraph->grad_accs[ihash] = grad_accs[i];
+            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];
+        } else if (node->flags & GGML_TENSOR_FLAG_LOSS) {
+            // loss tensors always need a gradient accumulator
+            cgraph->grad_accs[ihash] = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, node->ne);
+            cgraph->grads[ihash]     = cgraph->grad_accs[ihash];
        }
-        grads_needed[igrad] = true;
+        grads_needed[ihash] = true;
    }
    for (int i = n_nodes_f - 1; i >= 0; --i) {
        // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
        // use allocator to automatically make inplace operations
-        ggml_compute_backward(ctx_compute, cgraph, i, grads_needed);
+        ggml_compute_backward(ctx, cgraph, i, grads_needed);
    }
    free(grads_needed);
...
...
@@ -6033,8 +6039,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
    }
}
-struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
-    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads != NULL);
+struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph, bool force_grads) {
+    struct ggml_cgraph * result = ggml_new_graph_custom(ctx, cgraph->size, cgraph->grads || force_grads);
    ggml_graph_cpy(cgraph, result);
    return result;
}
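The new force_grads flag lets a caller duplicate a forward-only graph into one with gradient bookkeeping allocated, which is exactly what ggml_opt_build relies on when it dups gf into gb_grad. A hedged usage sketch (helper name is mine; ctx and gf assumed to be set up by the caller):

#include "ggml.h"

// Duplicate a forward-only graph into one that carries grad/grad_acc arrays,
// mirroring the gb_grad construction in ggml-opt.cpp.
struct ggml_cgraph * make_backward_capable(struct ggml_context * ctx, struct ggml_cgraph * gf) {
    return ggml_graph_dup(ctx, gf, /*force_grads =*/ true);
}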
...
...
@@ -6053,6 +6059,9 @@ struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
}
void ggml_graph_reset(struct ggml_cgraph * cgraph) {
+    if (!cgraph) {
+        return;
+    }
    GGML_ASSERT(cgraph->grads != NULL);
    for (int i = 0; i < cgraph->n_nodes; i++) {
...
...
@@ -6362,8 +6371,8 @@ void ggml_set_output(struct ggml_tensor * tensor) {
    tensor->flags |= GGML_TENSOR_FLAG_OUTPUT;
}
-void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
-    GGML_UNUSED(ctx); // TODO: remove this parameter
+void ggml_set_param(struct ggml_tensor * tensor) {
+    GGML_ASSERT(tensor->op == GGML_OP_NONE);
    tensor->flags |= GGML_TENSOR_FLAG_PARAM;
}
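Besides dropping the unused context argument, ggml_set_param now asserts tensor->op == GGML_OP_NONE, i.e. parameters must be flagged on leaf tensors before any ops are built on top of them. A sketch of the required ordering (context setup assumed; x assumed to have ne0 == 16):

#include "ggml.h"

// Flag the leaf first, then build ops that consume it. Calling
// ggml_set_param on `h` (an op result) would trip the GGML_OP_NONE assertion.
void build_params(struct ggml_context * ctx, struct ggml_tensor * x) {
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 16);
    ggml_set_param(w); // ok: w is a leaf, op == GGML_OP_NONE
    struct ggml_tensor * h = ggml_mul_mat(ctx, w, x);
    (void) h;
}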
...
...