OpenDAS / ollama / Commits

Commit 544b6739 (unverified), authored Nov 06, 2025 by Daniel Hiltgen, committed by GitHub Nov 06, 2025

ggml update to b6840 (#12791)

Parent: c4ba257c
Changes: 103
Showing 20 changed files with 991 additions and 566 deletions (+991 -566)

ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu                +100   -4
ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu                 +1   -1
ml/backend/ggml/ggml/src/ggml-cuda/common.cuh                  +0   -7
ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu                     +61 -172
ml/backend/ggml/ggml/src/ggml-cuda/cpy.cuh                     +1   -5
ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh            +1   -0
ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh               +2   -7
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu                   +12   -7
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu               +44  -57
ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu                     +40   -6
ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh                   +313  -31
ml/backend/ggml/ggml/src/ggml-cuda/mmid.cu                   +164   -0
ml/backend/ggml/ggml/src/ggml-cuda/mmid.cuh                    +5   -0
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu                      +3 -166
ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu                    +44  -28
ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu               +119  -68
ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh                +4   -3
ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt               +4   -2
ml/backend/ggml/ggml/src/ggml-impl.h                          +48   -2
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp     +25   -0
ml/backend/ggml/ggml/src/ggml-cuda/argsort.cu

#include "argsort.cuh"

#ifdef GGML_CUDA_USE_CUB
#   include <cub/cub.cuh>
using namespace cub;
#endif // GGML_CUDA_USE_CUB

static __global__ void init_indices(int * indices, const int ncols, const int nrows) {
    const int col = blockIdx.x * blockDim.x + threadIdx.x;
    const int row = blockIdx.y;

    if (col < ncols && row < nrows) {
        indices[row * ncols + col] = col;
    }
}

static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx <= nrows) {
        offsets[idx] = idx * ncols;
    }
}

#ifdef GGML_CUDA_USE_CUB
static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
                                     const float * x, int * dst, const int ncols, const int nrows,
                                     ggml_sort_order order, cudaStream_t stream) {
    ggml_cuda_pool_alloc<int>   temp_indices_alloc(pool, ncols * nrows);
    ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
    ggml_cuda_pool_alloc<int>   offsets_alloc(pool, nrows + 1);

    int   * temp_indices = temp_indices_alloc.get();
    float * temp_keys    = temp_keys_alloc.get();
    int   * d_offsets    = offsets_alloc.get();

    static const int block_size = 256;
    const dim3 grid_size((ncols + block_size - 1) / block_size, nrows);
    init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows);

    const dim3 offset_grid((nrows + block_size - 1) / block_size);
    init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows);

    cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream);

    size_t temp_storage_bytes = 0;

    if (order == GGML_SORT_ORDER_ASC) {
        DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes,
            temp_keys, temp_keys,     // keys (in-place)
            temp_indices, dst,        // values (indices)
            ncols * nrows, nrows,     // num items, num segments
            d_offsets, d_offsets + 1,
            0, sizeof(float) * 8,     // all bits
            stream);
    } else {
        DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes,
            temp_keys, temp_keys, temp_indices, dst,
            ncols * nrows, nrows, d_offsets, d_offsets + 1,
            0, sizeof(float) * 8, stream);
    }

    ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
    void * d_temp_storage = temp_storage_alloc.get();

    if (order == GGML_SORT_ORDER_ASC) {
        DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes,
            temp_keys, temp_keys, temp_indices, dst,
            ncols * nrows, nrows, d_offsets, d_offsets + 1,
            0, sizeof(float) * 8, stream);
    } else {
        DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes,
            temp_keys, temp_keys, temp_indices, dst,
            ncols * nrows, nrows, d_offsets, d_offsets + 1,
            0, sizeof(float) * 8, stream);
    }
}
#endif // GGML_CUDA_USE_CUB

// Bitonic sort implementation
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
    T tmp = a;
...
@@ -65,7 +141,12 @@ static int next_power_of_2(int x) {
    return n;
}

-static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+static void argsort_f32_i32_cuda_bitonic(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
    // bitonic sort requires ncols to be power of 2
    const int ncols_pad = next_power_of_2(ncols);
...
@@ -77,9 +158,11 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
    GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb);

    if (order == GGML_SORT_ORDER_ASC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_ASC>
+            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else if (order == GGML_SORT_ORDER_DESC) {
-        k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+        k_argsort_f32_i32<GGML_SORT_ORDER_DESC>
+            <<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
    } else {
        GGML_ABORT("fatal error");
    }
...
@@ -197,6 +280,19 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    if (src0->type == GGML_TYPE_I32) {
        argsort_i32_i32_cuda((const int32_t *) src0_d, (int *) dst_d, ncols, nrows, order, stream);
    } else {
-        argsort_f32_i32_cuda(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+#ifdef GGML_CUDA_USE_CUB
+        const int    ncols_pad      = next_power_of_2(ncols);
+        const size_t shared_mem     = ncols_pad * sizeof(int);
+        const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+        if (shared_mem > max_shared_mem || ncols > 1024) {
+            ggml_cuda_pool & pool = ctx.pool();
+            argsort_f32_i32_cuda_cub(pool, src0_d, (int *) dst_d, ncols, nrows, order, stream);
+        } else {
+            argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+        }
+#else
+        argsort_f32_i32_cuda_bitonic(src0_d, (int *) dst_d, ncols, nrows, order, stream);
+#endif
    }
}
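Note (not part of the commit): the CUB path above follows the standard two-phase calling convention of cub::DeviceSegmentedRadixSort, where the first call with a null workspace pointer only reports the temporary-storage size and the second call performs the per-row sort. The following is a minimal, self-contained sketch of that pattern with made-up sizes and buffer names; it uses separate output buffers rather than the in-place form used in the commit.

// Standalone sketch of the two-phase cub::DeviceSegmentedRadixSort::SortPairs pattern.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    const int ncols = 4, nrows = 2, n = ncols * nrows;
    float h_keys[n] = {3.f, 1.f, 2.f, 0.f,   9.f, 8.f, 11.f, 10.f};
    int   h_vals[n], h_offsets[nrows + 1];
    for (int r = 0; r < nrows; ++r) {
        h_offsets[r] = r * ncols;
        for (int c = 0; c < ncols; ++c) h_vals[r * ncols + c] = c; // per-row indices
    }
    h_offsets[nrows] = n;

    float *d_keys_in, *d_keys_out; int *d_vals_in, *d_vals_out, *d_offsets;
    cudaMalloc(&d_keys_in,  n * sizeof(float));
    cudaMalloc(&d_keys_out, n * sizeof(float));
    cudaMalloc(&d_vals_in,  n * sizeof(int));
    cudaMalloc(&d_vals_out, n * sizeof(int));
    cudaMalloc(&d_offsets, (nrows + 1) * sizeof(int));
    cudaMemcpy(d_keys_in, h_keys, n * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_vals_in, h_vals, n * sizeof(int),   cudaMemcpyHostToDevice);
    cudaMemcpy(d_offsets, h_offsets, (nrows + 1) * sizeof(int), cudaMemcpyHostToDevice);

    // Phase 1: query the temporary-storage size (workspace pointer is null).
    void * d_temp = nullptr; size_t temp_bytes = 0;
    cub::DeviceSegmentedRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
        d_vals_in, d_vals_out, n, nrows, d_offsets, d_offsets + 1);
    cudaMalloc(&d_temp, temp_bytes);

    // Phase 2: sort each row (segment) independently; the index values ride along with
    // the keys, so d_vals_out ends up holding the argsort of each row.
    cub::DeviceSegmentedRadixSort::SortPairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
        d_vals_in, d_vals_out, n, nrows, d_offsets, d_offsets + 1);

    cudaMemcpy(h_vals, d_vals_out, n * sizeof(int), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; ++i) printf("%d ", h_vals[i]); // expected: 3 1 2 0  1 0 3 2
    printf("\n");
    cudaFree(d_temp); cudaFree(d_keys_in); cudaFree(d_keys_out);
    cudaFree(d_vals_in); cudaFree(d_vals_out); cudaFree(d_offsets);
    return 0;
}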
ml/backend/ggml/ggml/src/ggml-cuda/binbcast.cu
...
@@ -272,7 +272,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
    const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
    const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);

-    if (block_nums.z > 65535) {
+    if (block_nums.z > 65535 || block_nums.y > 65535) {
        int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
        const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
        const uint3 prod_01  = init_fastdiv_values((uint32_t) (ne0 * ne1));
...
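Note (not part of the commit): the added condition reflects a hardware limit, since CUDA caps gridDim.y and gridDim.z at 65535 while gridDim.x can be much larger, so a broadcast whose y dimension also overflows has to fall back to the flattened 1-D launch path. A tiny illustrative check of that limit might look like the sketch below (helper name is invented for the example).

// Hedged sketch: true when a 3-D grid would exceed CUDA launch limits and the
// caller should switch to a flattened 1-D grid instead.
#include <cuda_runtime.h>

static bool grid_needs_flattening(dim3 grid) {
    const unsigned max_yz = 65535u;            // hardware limit for gridDim.y and gridDim.z
    return grid.y > max_yz || grid.z > max_yz; // gridDim.x has a much larger, separate limit
}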
ml/backend/ggml/ggml/src/ggml-cuda/common.cuh
...
@@ -982,13 +982,6 @@ struct ggml_cuda_graph {
    bool disable_due_to_failed_graph_capture = false;
    int number_consecutive_updates = 0;
    std::vector<ggml_graph_node_properties> ggml_graph_properties;
-    bool use_cpy_indirection = false;
-    std::vector<char *> cpy_dest_ptrs;
-    char ** dest_ptrs_d;
-    int dest_ptrs_size = 0;
-    // Index to allow each cpy kernel to be aware of it's position within the graph
-    // relative to other cpy nodes.
-    int graph_cpynode_index = -1;
#endif
};
...
ml/backend/ggml/ggml/src/ggml-cuda/cpy.cu
...
@@ -8,18 +8,16 @@
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);

template <cpy_kernel_t cpy_1>
-static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne,
-                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                               const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
+                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                               const int nb12, const int nb13) {
    const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

    if (i >= ne) {
        return;
    }

-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst_direct;
-
    // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
    // then combine those indices with the corresponding byte offsets to get the total offsets
    const int64_t i03 = i/(ne00 * ne01 * ne02);
...
@@ -63,18 +61,16 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
}

template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int ne,
-                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
    const int i = (blockDim.x * blockIdx.x + threadIdx.x) * qk;

    if (i >= ne) {
        return;
    }

-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst_direct;
-
    const int i03 = i/(ne00 * ne01 * ne02);
    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
...
@@ -91,18 +87,16 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst_direct, const int
}

template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int ne,
-                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) {
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
    const int i = (blockDim.x * blockIdx.x + threadIdx.x) * qk;

    if (i >= ne) {
        return;
    }

-    char * cdst = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst_direct;
-
    const int i03 = i/(ne00 * ne01 * ne02);
    const int i02 = (i - i03*ne00*ne01*ne02) / (ne00*ne01);
    const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
...
@@ -118,67 +112,47 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int
    cpy_blck(cx + x_offset, cdst + dst_offset);
}

-// Copy destination pointers to GPU to be available when pointer indirection is in use
-void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) {
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers
-        CUDA_CHECK(cudaStreamSynchronize(stream));
-        if (cuda_graph->dest_ptrs_d != nullptr) {
-            CUDA_CHECK(cudaFree(cuda_graph->dest_ptrs_d));
-        }
-        CUDA_CHECK(cudaMalloc(&cuda_graph->dest_ptrs_d, host_dest_ptrs_size*sizeof(char *)));
-        cuda_graph->dest_ptrs_size = host_dest_ptrs_size;
-    }
-    // copy destination pointers to GPU
-    CUDA_CHECK(cudaMemcpyAsync(cuda_graph->dest_ptrs_d, host_dest_ptrs, host_dest_ptrs_size*sizeof(char *), cudaMemcpyHostToDevice, stream));
-    cuda_graph->graph_cpynode_index = 0; // reset index
-#else
-    GGML_UNUSED_VARS(cuda_graph, host_dest_ptrs, host_dest_ptrs_size, stream);
-#endif
-}
-
template<typename src_t, typename dst_t>
static void ggml_cpy_flt_cuda(
-    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {

    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q8_0_cuda(
-    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
    GGML_ASSERT(ne % QK8_0 == 0);
    const int num_blocks = ne / QK8_0;
    cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q8_0_f32_cuda(
-    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q4_0_cuda(
-    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
    GGML_ASSERT(ne % QK4_0 == 0);
    const int num_blocks = ne / QK4_0;
    cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q4_0_f32_cuda(
...
@@ -187,22 +161,22 @@ static void ggml_cpy_q4_0_f32_cuda(
    const int nb00, const int nb01, const int nb02, const int nb03,
    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {

    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q4_1_cuda(
-    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
    GGML_ASSERT(ne % QK4_1 == 0);
    const int num_blocks = ne / QK4_1;
    cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q4_1_f32_cuda(
...
@@ -211,22 +185,22 @@ static void ggml_cpy_q4_1_f32_cuda(
    const int nb00, const int nb01, const int nb02, const int nb03,
    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {

    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q5_0_cuda(
-    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
    GGML_ASSERT(ne % QK5_0 == 0);
    const int num_blocks = ne / QK5_0;
    cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q5_0_f32_cuda(
...
@@ -235,22 +209,22 @@ static void ggml_cpy_q5_0_f32_cuda(
    const int nb00, const int nb01, const int nb02, const int nb03,
    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {

    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_q5_1_cuda(
-    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
    GGML_ASSERT(ne % QK5_1 == 0);
    const int num_blocks = ne / QK5_1;
    cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_q5_1_f32_cuda(
...
@@ -259,30 +233,29 @@ static void ggml_cpy_q5_1_f32_cuda(
    const int nb00, const int nb01, const int nb02, const int nb03,
    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    cudaStream_t stream) {

    const int num_blocks = ne;
    cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

static void ggml_cpy_f32_iq4_nl_cuda(
-    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
    GGML_ASSERT(ne % QK4_NL == 0);
    const int num_blocks = ne / QK4_NL;
    cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}

template <cpy_kernel_t cpy_1>
static __global__ void cpy_i32_i32(const char * cx, char * cdst, const int ne,
    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
-    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) {
+    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {
    const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
...
@@ -302,23 +275,20 @@ static __global__ void cpy_i32_i32(
    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;

-    char * cdst_ptr = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst;
-
-    cpy_1(cx + x_offset, cdst_ptr + dst_offset);
+    cpy_1(cx + x_offset, cdst + dst_offset);
}

static void ggml_cpy_i32_i32_cuda(
-    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
-    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
-    cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
+    const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
+    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream) {

    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
    cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, stream, cdst_indirect, graph_cpynode_index);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, stream);
}

-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
    const int64_t ne = ggml_nelements(src0);
    GGML_ASSERT(ne == ggml_nelements(src1));
...
@@ -352,16 +322,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
    char * src0_ddc = (char *) src0->data;
    char * src1_ddc = (char *) src1->data;

-    char ** dest_ptrs_d = nullptr;
-    int graph_cpynode_index = -1;
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if (ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
-        dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d;
-        graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index;
-    }
-#else
-    GGML_UNUSED(disable_indirection_for_this_node);
-#endif
-
    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
        GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY)
...
@@ -370,136 +330,65 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
        } else
#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY
        {
-            if (src0->type == GGML_TYPE_F32) {
-                ggml_cpy_flt_cuda<float, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
-            } else {
                CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream));
-            }
        }
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<float, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<float, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<float, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<half, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<half, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<half, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<half, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        // TODO consider converting to template
+        ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_flt_cuda<nv_bfloat16, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, half>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<nv_bfloat16, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<nv_bfloat16, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_flt_cuda<float, int32_t>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<float, int32_t>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_flt_cuda<int32_t, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+        ggml_cpy_flt_cuda<int32_t, float>(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
    } else {
        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
                   ggml_type_name(src0->type), ggml_type_name(src1->type));
    }
-
-#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS)
-    if (ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) {
-        ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index;
-    }
-#else
-    GGML_UNUSED(disable_indirection_for_this_node);
-#endif
}

void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
-    bool disable_indirection = true;
-    ggml_cuda_cpy(ctx, src0, dst, disable_indirection);
+    ggml_cuda_cpy(ctx, src0, dst);
}

-void * ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
-    if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
-        // Prioritize CUDA graph compatibility over direct memory copy optimization.
-        // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
-        if (src0->type == GGML_TYPE_F32) {
-            return (void*) cpy_flt<cpy_1_flt<float, float>>;
-        } else {
-            return nullptr;
-        }
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<float, float>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<float, nv_bfloat16>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<float, half>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>;
-    } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q8_0_f32, QK8_0>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>;
-    } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>;
-    } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q5_0, QK5_0>;
-    } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        return (void*) cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        return (void*) cpy_f32_q<cpy_blck_f32_q5_1, QK5_1>;
-    } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1>;
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<half, half>>;
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<half, nv_bfloat16>>;
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<half, float>>;
-    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
-        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, half>>;
-    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
-        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, nv_bfloat16>>;
-    } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
-        return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
-    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
-        return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
-    } else {
-        GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
-                   ggml_type_name(src0->type), ggml_type_name(src1->type));
-    }
-}
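Note (not part of the commit): the copy kernels above all share one addressing scheme, documented in the kernel comment: a flattened element index i is decomposed into per-dimension indices and then recombined with the per-dimension byte strides. A small host-side sketch of that arithmetic, using the same ne/nb naming as the kernels, is shown below for intuition only.

// CPU model of the index math used by cpy_flt / cpy_f32_q / cpy_q_f32:
// decompose flattened index i, then rebuild a byte offset from the strides.
#include <cstdint>

static int64_t src_offset_for(int64_t i,
                              int64_t ne00, int64_t ne01, int64_t ne02,
                              int64_t nb00, int64_t nb01, int64_t nb02, int64_t nb03) {
    const int64_t i03 = i / (ne00 * ne01 * ne02);
    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
    return i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03; // byte offset into the tensor
}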
ml/backend/ggml/ggml/src/ggml-cuda/cpy.cuh
...
@@ -2,10 +2,6 @@
#define CUDA_CPY_BLOCK_SIZE 64

-void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
+void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);

void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
-
-void * ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
-
-void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);
ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh
...
@@ -895,6 +895,7 @@ void launch_fattn(
    const dim3 block_dim(warp_size, nwarps, 1);
    int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
    CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
+    GGML_ASSERT(max_blocks_per_sm > 0);
    int parallel_blocks = max_blocks_per_sm;

    dim3 blocks_num;
...
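Note (not part of the commit): the added assertion guards the occupancy query, since cudaOccupancyMaxActiveBlocksPerMultiprocessor reports zero active blocks when the kernel's register or shared-memory footprint cannot fit on an SM at the requested block size, which would leave parallel_blocks unusable. A minimal standalone use of that API, with a dummy kernel invented for the example:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float * x) { x[threadIdx.x] += 1.0f; }

int main() {
    int max_blocks_per_sm = 0;
    const int    block_size = 256;
    const size_t dyn_smem   = 0; // dynamic shared memory per block
    cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, dummy_kernel, block_size, dyn_smem);
    printf("active blocks per SM at block size %d: %d\n", block_size, max_blocks_per_sm);
    return 0;
}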
ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh
...
@@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
    const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
    const int nwarps = nthreads / WARP_SIZE;
    fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
-    constexpr bool need_f16_K = false;
-    constexpr bool need_f16_V = false;
+    const bool need_f16_K = type_K == GGML_TYPE_F16;
+    const bool need_f16_V = type_V == GGML_TYPE_F16;
    constexpr size_t nbytes_shared = 0;
    launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
}
...
@@ -526,11 +526,6 @@ template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q   = dst->src[0];
-    const ggml_tensor * K   = dst->src[1];
-    const ggml_tensor * V   = dst->src[2];
-
-    GGML_ASSERT(K->type == type_K);
-    GGML_ASSERT(V->type == type_V);

    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
...
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
...
@@ -117,10 +117,14 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
}

#define FATTN_VEC_CASE(D, type_K, type_V)                                                                             \
-    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {                                              \
+    {                                                                                                                 \
+        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16);      \
+        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16);      \
+        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) {                                                          \
            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                                            \
            return;                                                                                                    \
        }                                                                                                              \
+    }                                                                                                                 \

#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
    FATTN_VEC_CASE( 64, type_K, type_V)       \
...
@@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
#endif // GGML_CUDA_FA_ALL_QUANTS

    switch (K->type) {
+        case GGML_TYPE_F32:
        case GGML_TYPE_F16:
            break;
        case GGML_TYPE_Q4_1:
...
@@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
    // If Turing tensor cores available, use them:
    if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
        if (can_use_vector_kernel) {
-            if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                    return BEST_FATTN_KERNEL_VEC;
                }
...
@@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
    // If there are no tensor cores available, use the generic tile kernel:
    if (can_use_vector_kernel) {
-        if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
            if (Q->ne[1] == 1) {
                if (!gqa_opt_applies) {
                    return BEST_FATTN_KERNEL_VEC;
...
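Note (not part of the commit): the reworked FATTN_VEC_CASE macro now also accepts an F32 K or V tensor for a kernel instantiated with F16 types, and the matching need_f16_K/need_f16_V change in fattn-vec.cuh asks launch_fattn to provide F16 data in that case. Written out as a plain predicate, the acceptance test is simply the sketch below (helper name invented for illustration).

// Hedged restatement of the macro's type check as a standalone predicate.
#include "ggml.h"

static bool fattn_vec_type_okay(ggml_type actual, ggml_type wanted) {
    return actual == wanted || (actual == GGML_TYPE_F32 && wanted == GGML_TYPE_F16);
}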
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
...
@@ -2774,11 +2774,10 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
}

#ifdef USE_CUDA_GRAPH
-static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
+static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
    int batch_size, bool use_cuda_graph) {

    // Loop over nodes in GGML graph to obtain info needed for CUDA graph
-    cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();

    const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
    const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
...
@@ -2839,33 +2838,11 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
            }
        }

-        if (node->op == GGML_OP_CPY) {
-
-            // Store the pointers which are updated for each token, such that these can be sent
-            // to the device and accessed using indirection from CUDA graph
-            cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
-
-            // store a pointer to each copy op CUDA kernel to identify it later
-            void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
-            if (!ptr) {
-                use_cuda_graph = false;
-#ifndef NDEBUG
-                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
-#endif
-            }
-        }
-
        if (!use_cuda_graph) {
            break;
        }
    }

-    if (use_cuda_graph) {
-        cuda_ctx->cuda_graph->use_cpy_indirection = true;
-        // copy pointers to GPU so they can be accessed via indirection within CUDA graph
-        ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
-    }
-
    return use_cuda_graph;
}
...
@@ -2884,7 +2861,6 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
    if (node->data != graph_node_properties->node_address &&
        node->op != GGML_OP_CPY &&
        node->op != GGML_OP_VIEW) {
        return false;
    }
...
@@ -2905,7 +2881,6 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
    for (int i = 0; i < GGML_MAX_SRC; i++) {
        if (node->src[i] &&
            node->src[i]->data != graph_node_properties->src_address[i] &&
            node->op != GGML_OP_CPY &&
            node->op != GGML_OP_VIEW) {
            return false;
...
@@ -2985,18 +2960,15 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
#endif

    //TODO: remove special case once ggml_can_fuse can handle empty nodes
-    std::initializer_list<enum ggml_op> topk_moe_ops           = ggml_cuda_topk_moe_ops(false);
-    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = ggml_cuda_topk_moe_ops(true);
-
-    if (ops.size() == topk_moe_ops_with_norm.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops_with_norm.begin())) {
-        if (node_idx + topk_moe_ops_with_norm.size() > (size_t) cgraph->n_nodes) {
-            return false;
-        }
-        for (size_t i = 0; i < topk_moe_ops_with_norm.size(); i++) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops_with_norm.begin()[i]) return false;
-        }
+    std::initializer_list<enum ggml_op> topk_moe_ops                 = ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/ false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_with_norm       = ggml_cuda_topk_moe_ops(/*with_norm=*/ true, /*delayed_softmax=*/ false);
+    std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax = ggml_cuda_topk_moe_ops(/*with_norm=*/ false, /*delayed_softmax=*/ true);
+
+    if (ops.size() == topk_moe_ops_with_norm.size() && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 8 })) {
        ggml_tensor * softmax = cgraph->nodes[node_idx];
        ggml_tensor * weights = cgraph->nodes[node_idx + 8];
...
@@ -3005,18 +2977,20 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
        }
    }

-    if (ops.size() == topk_moe_ops.size() && std::equal(ops.begin(), ops.end(), topk_moe_ops.begin())) {
-        if (node_idx + topk_moe_ops.size() > (size_t) cgraph->n_nodes) {
-            return false;
-        }
-        for (size_t i = 0; i < topk_moe_ops.size(); i++) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_ops.begin()[i]) return false;
-        }
-        ggml_tensor * softmax = cgraph->nodes[node_idx];
-        ggml_tensor * weights = cgraph->nodes[node_idx + 4];
+    if (ops.size() == topk_moe_ops.size() && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
+        ggml_tensor * softmax = cgraph->nodes[node_idx];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 4];
+        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
+            return true;
+        }
+    }
+
+    if (ops.size() == topk_moe_ops_delayed_softmax.size() && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2, node_idx + 5 })) {
+        ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
+        ggml_tensor * weights = cgraph->nodes[node_idx + 5];

        if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
            return true;
        }
...
@@ -3052,7 +3026,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
    }

    //if rms norm is the B operand, then we don't handle broadcast
-    if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+    if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) {
        return false;
    }
...
@@ -3121,7 +3095,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
            if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) {
                ggml_tensor * weights          = cgraph->nodes[i + 8];
                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
-                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ true);
+                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
+                                      /*delayed softmax*/ false);
                i += 8;
                continue;
            }
...
@@ -3129,11 +3104,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
            if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) {
                ggml_tensor * weights          = cgraph->nodes[i + 4];
                ggml_tensor * selected_experts = cgraph->nodes[i + 3];
-                ggml_cuda_op_topk_moe(*cuda_ctx, node, weights, selected_experts, /*with norm*/ false);
+                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
+                                      /*delayed softmax*/ false);
                i += 4;
                continue;
            }

+            if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) {
+                ggml_tensor * weights = cgraph->nodes[i + 5];
+                ggml_tensor * ids     = cgraph->nodes[i + 1];
+
+                ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false,
+                                      /*delayed_softmax*/ true);
+
+                i += 5;
+                continue;
+            }
+
            if (node->op == GGML_OP_ADD) {
                int n_fuse = 0;
                ggml_op ops[8];
...
@@ -3278,7 +3265,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
        if (use_cuda_graph) {
            cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);

-            use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, batch_size, use_cuda_graph);
+            use_cuda_graph = check_node_graph_compatibility(cgraph, batch_size, use_cuda_graph);

            // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
            if (use_cuda_graph && cuda_graph_update_required) {
...
@@ -3305,10 +3292,6 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
            CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
        }

-        if (!use_cuda_graph) {
-            cuda_ctx->cuda_graph->use_cpy_indirection = false;
-        }
-
#else
    bool use_cuda_graph = false;
    bool cuda_graph_update_required = false;
...
@@ -3922,12 +3905,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
        case GGML_OP_CONV_2D_DW:
        case GGML_OP_CONV_TRANSPOSE_2D:
        case GGML_OP_POOL_2D:
-        case GGML_OP_SUM:
        case GGML_OP_ACC:
            return true;
+        case GGML_OP_SUM:
+            return ggml_is_contiguous_rows(op->src[0]);
+        case GGML_OP_ARGSORT:
+            // TODO: Support arbitrary column width
+#ifndef GGML_CUDA_USE_CUB
+            return op->src[0]->ne[0] <= 1024;
+#else
+            return true;
+#endif
        case GGML_OP_SUM_ROWS:
        case GGML_OP_MEAN:
        case GGML_OP_GROUP_NORM:
...
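Note (not part of the commit): the supports_op change mirrors the argsort dispatch added earlier in this commit, since without CUB only the bitonic kernel is available and it sorts one padded row per block in shared memory, which is why columns are capped at 1024. Expressed as a tiny helper (invented name, illustration only):

// Sketch of the argsort support rule: without CUB the bitonic kernel handles at most
// 1024 columns (one padded row must fit in a block's shared memory).
#include <cstdint>

static bool cuda_supports_argsort(int64_t ne0, bool have_cub) {
    return have_cub || ne0 <= 1024;
}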
ml/backend/ggml/ggml/src/ggml-cuda/mmf.cu

 #include "ggml.h"
 #include "mmf.cuh"
+#include "mmid.cuh"

 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
...
@@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
     const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
     const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;

+    mmf_ids_data ids_info{};
+    mmf_ids_data * ids_info_ptr = nullptr;
+    ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
+    ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
+    ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
+
     // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
     const int64_t ncols_dst     = ids ? ne2 : ne1;
     const int64_t nchannels_dst = ids ? ne1 : ne2;
...
@@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         nchannels_y = ids->ne[0];
     }

+    if (ids && ncols_dst > 16) {
+        const int64_t n_expert_used = ids->ne[0];
+        const int64_t n_experts     = ne02;
+        const int64_t n_tokens      = ne12;
+        const int64_t ne_get_rows   = n_tokens * n_expert_used;
+
+        ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
+        ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
+        expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
+
+        const int si1  = static_cast<int>(ids_s1);
+        const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
+
+        GGML_ASSERT(sis1 > 0);
+
+        ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
+            static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11),
+            si1, sis1, ctx.stream());
+        CUDA_CHECK(cudaGetLastError());
+
+        ids_info.ids_src_compact   = ids_src_compact_dev.get();
+        ids_info.ids_dst_compact   = ids_dst_compact_dev.get();
+        ids_info.expert_bounds_dev = expert_bounds_dev.get();
+        ids_info.n_experts         = static_cast<int>(n_experts);
+        ids_info.sis1              = sis1;
+        ids_info_ptr               = &ids_info;
+    }
+
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
...
@@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
             mul_mat_f_switch_cols_per_block(src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T,
                 stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_F16: {
             const half2 * src0_d = (const half2 *) src0->data;
...
@@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
             mul_mat_f_switch_cols_per_block(src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T,
                 stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
...
@@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
             mul_mat_f_switch_cols_per_block(src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T,
                 stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
...
@@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
     }

     if (mul_mat_id) {
-        if (type == GGML_TYPE_F32 && src1_ncols > 32) {
+        if (src0_ne[1] <= 1024 && src1_ncols > 512) {
             return false;
-        }
-        if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
+        } else if (src0_ne[1] > 1024 && src1_ncols > 128) {
             return false;
         }
     } else {
...
...
ml/backend/ggml/ggml/src/ggml-cuda/mmf.cuh
...
@@ -7,6 +7,14 @@ using namespace ggml_cuda_mma;

 #define MMF_ROWS_PER_BLOCK 32

+struct mmf_ids_data {
+    const int32_t * ids_src_compact   = nullptr;
+    const int32_t * ids_dst_compact   = nullptr;
+    const int32_t * expert_bounds_dev = nullptr;
+    int n_experts = 0;
+    int sis1      = 0;
+};
+
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);

 bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
...
@@ -224,6 +232,250 @@ static __global__ void mul_mat_f(
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 }

+//This kernel is for larger batch sizes of mul_mat_id
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f_ids(
+        const T * __restrict__ x, const float * __restrict__ y,
+        const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
+        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
+        const int ncols, const int ncols_dst_total, const int nchannels_dst,
+        const int stride_row, const int stride_col_y, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const uint3 sis1_fd, const uint3 nch_fd) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    constexpr int warp_size     = ggml_cuda_get_physical_warp_size();
+    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int ntA           = rows_per_block / tile_A::I;
+    constexpr int ntB           = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+    const int row0 = blockIdx.x * rows_per_block;
+
+    const int expert_idx   = blockIdx.y;
+    const int expert_start = expert_bounds[expert_idx];
+    const int expert_end   = expert_bounds[expert_idx + 1];
+    const int ncols_expert = expert_end - expert_start;
+
+    const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
+    const int tile_idx         = blockIdx.z;
+    if (tile_idx >= tiles_for_expert) {
+        return;
+    }
+    const int col_base = tile_idx * cols_per_block;
+
+    GGML_UNUSED(channel_ratio);
+
+    const int channel_x  = expert_idx;
+    const int sample_dst = 0;
+    const int sample_x   = sample_dst / sample_ratio;
+    const int sample_y   = sample_dst;
+
+    x   += int64_t(sample_x)  *stride_sample_x   + channel_x*stride_channel_x + row0*stride_row;
+    y   += int64_t(sample_y)  *stride_sample_y;
+    dst += int64_t(sample_dst)*stride_sample_dst;
+
+    const int32_t * ids_src_expert = ids_src_compact + expert_start;
+    const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
+
+    extern __shared__ char data_mmv[];
+    char * compute_base = data_mmv;
+
+    //const float2 * y2 = (const float2 *) y;
+
+    tile_C C[ntA][ntB];
+
+    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
+
+    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+        tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int i = 0; i < tile_A::I; ++i) {
+                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+            }
+        }
+
+        if constexpr (std::is_same_v<T, float>) {
+            float vals_buf[2][tile_B::I];
+            auto gather_tile = [&](int tile_idx_local, float * vals) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j        = j0 + tile_idx_local*tile_B::I;
+                    const int global_j = col_base + j;
+                    float val = 0.0f;
+                    if (j < cols_per_block && global_j < ncols_expert) {
+                        const int   src_entry = ids_src_expert[global_j];
+                        const uint2 qrm       = fast_div_modulo((uint32_t) src_entry, sis1_fd);
+                        const int   token     = (int) qrm.x;
+                        const int   channel   = (int) qrm.y;
+                        if (token < ncols_dst_total) {
+                            val = y[channel*stride_channel_y + token*stride_col_y + col];
+                        }
+                    }
+                    vals[j0] = val;
+                }
+            };
+
+            gather_tile(0, vals_buf[0]);
+
+            int curr_buf = 0;
+            int next_buf = 1;
+#pragma unroll
+            for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
+                }
+
+                if (itB + 1 < ntB) {
+                    gather_tile(itB + 1, vals_buf[next_buf]);
+                }
+
+#pragma unroll
+                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                    tile_B B;
+                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                    for (int itA = 0; itA < ntA; ++itA) {
+                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                    }
+                }
+
+                if (itB + 1 < ntB) {
+                    curr_buf ^= 1;
+                    next_buf ^= 1;
+                }
+            }
+        } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+            float2 vals_buf[2][tile_B::I];
+            auto gather_tile = [&](int tile_idx_local, float2 * vals) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j        = j0 + tile_idx_local*tile_B::I;
+                    const int global_j = col_base + j;
+                    float2 tmp = make_float2(0.0f, 0.0f);
+                    if (j < cols_per_block && global_j < ncols_expert) {
+                        const int   src_entry = ids_src_expert[global_j];
+                        const uint2 qrm       = fast_div_modulo((uint32_t) src_entry, sis1_fd);
+                        const int   token     = (int) qrm.x;
+                        const int   channel   = (int) qrm.y;
+                        if (token < ncols_dst_total) {
+                            tmp = *(const float2 *) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
+                        }
+                    }
+                    vals[j0] = tmp;
+                }
+            };
+
+            if (ntB > 0) {
+                gather_tile(0, vals_buf[0]);
+            }
+
+            int curr_buf = 0;
+            int next_buf = 1;
+#pragma unroll
+            for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const float2 tmp = vals_buf[curr_buf][j0];
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                }
+
+                if (itB + 1 < ntB) {
+                    gather_tile(itB + 1, vals_buf[next_buf]);
+                }
+
+#pragma unroll
+                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                    tile_B B;
+                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                    for (int itA = 0; itA < ntA; ++itA) {
+                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                    }
+                }
+
+                if (itB + 1 < ntB) {
+                    curr_buf ^= 1;
+                    next_buf ^= 1;
+                }
+            }
+        } else {
+            static_assert(std::is_same_v<T, void>, "unsupported type");
+        }
+    }
+
+    float * buf_iw = (float *) compute_base;
+    constexpr int kiw = nwarps*rows_per_block + 4;
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+#pragma unroll
+    for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+                const int j = itB*tile_C::J + tile_C::get_j(l);
+                buf_iw[j*kiw + i] = C[itA][itB].x[l];
+            }
+        }
+    }
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+            return;
+        }
+
+        float sum = 0.0f;
+        static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+            const int i = i0 + threadIdx.x;
+            sum += buf_iw[j*kiw + i];
+        }
+
+        const int global_j = col_base + j;
+        if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
+            const int   dst_entry = ids_dst_expert[global_j];
+            const uint2 qrm       = fast_div_modulo((uint32_t) dst_entry, nch_fd);
+            const int   token     = (int) qrm.x;
+            if (token < ncols_dst_total) {
+                const int slot = (int) qrm.y;
+                dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
+            }
+        }
+    }
+#else
+    GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
+        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
+    NO_DEVICE_CODE;
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
 template <typename T, int cols_per_block, int nwarps>
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
...
@@ -232,11 +484,33 @@ static inline void mul_mat_f_switch_ids(
         const int64_t stride_col_id, const int64_t stride_row_id,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
-    if (ids) {
+        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
+        const mmf_ids_data * ids_data) {
+    const bool has_ids_data = ids_data && ids_data->ids_src_compact;
+
+    // Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)
+    // we prefer the normal mul_mat_f path with has_ids=true.
+    if (has_ids_data && ncols_dst > 16) {
+        const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
+        if (max_tiles == 0) {
+            return;
+        }
+        dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
+
+        const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
+        const uint3 nch_fd  = init_fastdiv_values((uint32_t) nchannels_dst);
+
+        mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
+            (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
+             ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+             channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
+    } else if (ids) {
         const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
         dim3 block_nums_ids = block_nums;
         block_nums_ids.y *= col_tiles;
         mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
             (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
              stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
...
@@ -258,7 +532,7 @@ void mul_mat_f_cuda(
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-       cudaStream_t stream) {
+       cudaStream_t stream, const mmf_ids_data * ids_data) {
    typedef tile<16, 8, T> tile_A;
    typedef tile< 8, 8, T> tile_B;
...
@@ -290,7 +564,7 @@ void mul_mat_f_cuda(
    const int nbytes_shared       = std::max(nbytes_shared_iter, nbytes_shared_combine);
    const int nbytes_slotmap      = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
    const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
-   const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
+   const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
    const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
    const dim3 block_dims(warp_size, nwarps_best, 1);
...
@@ -300,49 +574,57 @@ void mul_mat_f_cuda(
            mul_mat_f_switch_ids<T, cols_per_block, 1>(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data);
        } break;
        case 2: {
            mul_mat_f_switch_ids<T, cols_per_block, 2>(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data);
        } break;
        case 3: {
            mul_mat_f_switch_ids<T, cols_per_block, 3>(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data);
        } break;
        case 4: {
            mul_mat_f_switch_ids<T, cols_per_block, 4>(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data);
        } break;
        case 5: {
            mul_mat_f_switch_ids<T, cols_per_block, 5>(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data);
        } break;
        case 6: {
            mul_mat_f_switch_ids<T, cols_per_block, 6>(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data);
        } break;
        case 7: {
            mul_mat_f_switch_ids<T, cols_per_block, 7>(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data);
        } break;
        case 8: {
            mul_mat_f_switch_ids<T, cols_per_block, 8>(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+               sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data);
        } break;
        default: {
            GGML_ABORT("fatal error");
...
@@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block(
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
        const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-       cudaStream_t stream) {
+       cudaStream_t stream, const mmf_ids_data * ids_data) {
    const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
...
@@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block(
        case 1: {
            mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 2: {
            mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 3: {
            mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 4: {
            mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 5: {
            mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 6: {
            mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 7: {
            mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 8: {
            mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 9: {
            mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 10: {
            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 11: {
            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 12: {
            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 13: {
            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 14: {
            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 15: {
            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        case 16: {
            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id,
                nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
        } break;
        default: {
            GGML_ABORT("fatal error");
...
@@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block(
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
-       cudaStream_t stream);
+       cudaStream_t stream, const mmf_ids_data * ids_data);

#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
...
...
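As a rough mental model of the new mul_mat_f_ids dispatch: the host launches a grid of (row tiles, n_experts, max_tiles) blocks, where max_tiles is the tile count of the largest possible expert slice, and each block consults expert_bounds for its own expert and exits early when its z index lies past that expert's actual tile count. The host-side sketch below (not part of the commit) only replays that mapping with made-up sizes; the names are illustrative and it is not the ggml API.

// Illustrative only: (row tile, expert, column tile) grid mapping with a fabricated expert_bounds.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int rows_per_block = 32, cols_per_block = 32;
    const int nrows_x = 128, ncols_dst = 96;                 // total dst columns across all experts
    const std::vector<int> expert_bounds = {0, 40, 40, 96};  // 3 experts; expert 1 received no tokens

    const int n_experts = (int) expert_bounds.size() - 1;
    const int max_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
    std::printf("launch grid = (%d, %d, %d)\n", nrows_x / rows_per_block, n_experts, max_tiles);

    for (int expert = 0; expert < n_experts; ++expert) {
        const int ncols_expert = expert_bounds[expert + 1] - expert_bounds[expert];
        const int tiles        = (ncols_expert + cols_per_block - 1) / cols_per_block;
        for (int tile = 0; tile < max_tiles; ++tile) {
            if (tile >= tiles) {
                continue;  // mirrors the kernel's early return for blocks past this expert's slice
            }
            const int lo = expert_bounds[expert] + tile * cols_per_block;
            const int hi = expert_bounds[expert] + std::min((tile + 1) * cols_per_block, ncols_expert);
            std::printf("expert %d, tile %d -> compact columns [%d, %d)\n", expert, tile, lo, hi);
        }
    }
    return 0;
}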
ml/backend/ggml/ggml/src/ggml-cuda/mmid.cu (new file, mode 100644)

#include "common.cuh"
#include "mmid.cuh"

// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
struct mm_ids_helper_store {
    uint32_t data;

    __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
        data = (it & 0x003FFFFF) | (iex_used << 22);
    }

    __device__ uint32_t it() const {
        return data & 0x003FFFFF;
    }

    __device__ uint32_t iex_used() const {
        return data >> 22;
    }
};
static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");

// Helper function for mul_mat_id, converts ids to a more convenient format.
// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
// ids_dst describes the same mapping but for the dst tensor.
// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
template <int n_expert_used_template>
__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mm_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
    const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
    const int expert = blockIdx.x;

    extern __shared__ char data_mm_ids_helper[];
    mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;

    int nex_prev   = 0; // Number of columns for experts with a lower index.
    int it_compact = 0; // Running index for the compact slice of this expert.

    if constexpr (n_expert_used_template == 0) {
        // Generic implementation:
        for (int it = 0; it < n_tokens; ++it) {
            int iex_used = -1; // The index at which the expert is used, if any.
            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
                const int expert_used = ids[it*si1 + iex];
                nex_prev += expert_used < expert;

                if (expert_used == expert) {
                    iex_used = iex;
                }
            }

            if (iex_used != -1) {
                store[it_compact] = mm_ids_helper_store(it, iex_used);
            }

            if (warp_reduce_any<warp_size>(iex_used != -1)) {
                it_compact++;
            }
        }
    } else {
        // Implementation optimized for specific numbers of experts used:
        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
            const int it = it0 + threadIdx.x / neu_padded;

            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
                ids[it*si1 + iex] : INT_MAX;
            const int iex_used = expert_used == expert ? iex : -1;
            nex_prev += expert_used < expert;

            // Whether the threads at this token position have used the expert:
            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);

            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
            int it_compact_add_lower = 0;
#pragma unroll
            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
                if (threadIdx.x >= static_cast<unsigned int>(offset)) {
                    it_compact_add_lower += tmp;
                }
            }

            if (iex_used != -1) {
                store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
            }

            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
        }
    }
    nex_prev = warp_reduce_sum<warp_size>(nex_prev);

    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
        const mm_ids_helper_store store_it = store[itc];
        const int it       = store_it.it();
        const int iex_used = store_it.iex_used();
        ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
    }

    if (threadIdx.x != 0) {
        return;
    }

    expert_bounds[expert] = nex_prev;

    if (expert < static_cast<int>(gridDim.x) - 1) {
        return;
    }

    expert_bounds[gridDim.x] = nex_prev + it_compact;
}

template <int n_expert_used_template>
static void launch_mm_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
    GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mm_ids_helper_store");
    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");

    const int id = ggml_cuda_get_device();
    const int warp_size = ggml_cuda_info().devices[id].warp_size;
    const size_t smpbo  = ggml_cuda_info().devices[id].smpbo;
    CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);

    const dim3 num_blocks(n_experts, 1, 1);
    const dim3 block_size(warp_size, 1, 1);
    const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
    GGML_ASSERT(nbytes_shared <= smpbo);
    mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
}

void ggml_cuda_launch_mm_ids_helper(
        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
        const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
    switch (n_expert_used) {
        case 2:
            launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 4:
            launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 6:
            launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 8:
            launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 16:
            launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        case 32:
            launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
        default:
            launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
            break;
    }
}
ml/backend/ggml/ggml/src/ggml-cuda/mmid.cuh (new file, mode 100644)

#pragma once

void ggml_cuda_launch_mm_ids_helper(
        const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
        int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
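For reference, the compaction that mm_ids_helper performs on the GPU can be written as a few nested loops on the CPU: walk the experts in order and, for each expert, collect the (token, slot) pairs that route to it, recording the flattened src1 column (it*sis1 + iex % nchannels_y), the flattened dst column (it*n_expert_used + iex), and the per-expert offsets in expert_bounds. The sketch below (not part of the commit) is a hedged reference with illustrative names, not the ggml implementation, and it assumes each token lists a given expert at most once.

// Illustrative only: host-side reference of the ids compaction.
#include <cstdint>
#include <cstdio>
#include <vector>

static void mm_ids_ref(const std::vector<int32_t> & ids,       // [n_tokens x n_expert_used], row stride si1
                       int n_experts, int n_tokens, int n_expert_used,
                       int nchannels_y, int si1, int sis1,
                       std::vector<int32_t> & ids_src1,         // flattened src1 column per compact slot
                       std::vector<int32_t> & ids_dst,          // flattened dst column per compact slot
                       std::vector<int32_t> & expert_bounds) {  // [n_experts + 1] slice boundaries
    ids_src1.clear();
    ids_dst.clear();
    expert_bounds.assign(n_experts + 1, 0);
    for (int expert = 0; expert < n_experts; ++expert) {
        expert_bounds[expert] = (int32_t) ids_src1.size();      // start of this expert's compact slice
        for (int it = 0; it < n_tokens; ++it) {
            for (int iex = 0; iex < n_expert_used; ++iex) {
                if (ids[it*si1 + iex] == expert) {
                    ids_src1.push_back(it*sis1 + iex % nchannels_y);
                    ids_dst .push_back(it*n_expert_used + iex);
                }
            }
        }
    }
    expert_bounds[n_experts] = (int32_t) ids_src1.size();
}

int main() {
    // 3 tokens, 2 experts used per token, 4 experts total; si1 equals n_expert_used here.
    const std::vector<int32_t> ids = {0, 2,  1, 0,  2, 3};
    std::vector<int32_t> ids_src1, ids_dst, expert_bounds;
    mm_ids_ref(ids, /*n_experts=*/4, /*n_tokens=*/3, /*n_expert_used=*/2,
               /*nchannels_y=*/1, /*si1=*/2, /*sis1=*/1, ids_src1, ids_dst, expert_bounds);
    for (int e = 0; e < 4; ++e) {
        std::printf("expert %d: compact slots [%d, %d)\n", e, expert_bounds[e], expert_bounds[e + 1]);
    }
    return 0;
}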
ml/backend/ggml/ggml/src/ggml-cuda/mmq.cu

 #include "mmq.cuh"
 #include "quantize.cuh"

 #include <vector>

-// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
-struct mmq_ids_helper_store {
-    uint32_t data;
-
-    __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
-        data = (it & 0x003FFFFF) | (iex_used << 22);
-    }
-
-    __device__ uint32_t it() const {
-        return data & 0x003FFFFF;
-    }
-
-    __device__ uint32_t iex_used() const {
-        return data >> 22;
-    }
-};
-static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
-
-// Helper function for mul_mat_id, converts ids to a more convenient format.
-// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
-// ids_dst describes the same mapping but for the dst tensor.
-// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
-template <int n_expert_used_template>
-__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
-static __global__ void mmq_ids_helper(
-        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
-        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
-    const int expert = blockIdx.x;
-
-    extern __shared__ char data_mmq_ids_helper[];
-    mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
-
-    int nex_prev   = 0; // Number of columns for experts with a lower index.
-    int it_compact = 0; // Running index for the compact slice of this expert.
-
-    if constexpr (n_expert_used_template == 0) {
-        // Generic implementation:
-        for (int it = 0; it < n_tokens; ++it) {
-            int iex_used = -1; // The index at which the expert is used, if any.
-            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
-                const int expert_used = ids[it*si1 + iex];
-                nex_prev += expert_used < expert;
-
-                if (expert_used == expert) {
-                    iex_used = iex;
-                }
-            }
-
-            if (iex_used != -1) {
-                store[it_compact] = mmq_ids_helper_store(it, iex_used);
-            }
-
-            if (warp_reduce_any<warp_size>(iex_used != -1)) {
-                it_compact++;
-            }
-        }
-    } else {
-        // Implementation optimized for specific numbers of experts used:
-        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
-        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
-        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
-            const int it = it0 + threadIdx.x / neu_padded;
-
-            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
-            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
-                ids[it*si1 + iex] : INT_MAX;
-            const int iex_used = expert_used == expert ? iex : -1;
-            nex_prev += expert_used < expert;
-
-            // Whether the threads at this token position have used the expert:
-            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
-
-            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
-            int it_compact_add_lower = 0;
-#pragma unroll
-            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
-                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
-                if (threadIdx.x >= static_cast<unsigned int>(offset)) {
-                    it_compact_add_lower += tmp;
-                }
-            }
-
-            if (iex_used != -1) {
-                store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
-            }
-
-            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
-            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
-        }
-    }
-    nex_prev = warp_reduce_sum<warp_size>(nex_prev);
-
-    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
-        const mmq_ids_helper_store store_it = store[itc];
-        const int it       = store_it.it();
-        const int iex_used = store_it.iex_used();
-        ids_src1[nex_prev + itc] = it*sis1 + iex_used % nchannels_y;
-        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
-    }
-
-    if (threadIdx.x != 0) {
-        return;
-    }
-
-    expert_bounds[expert] = nex_prev;
-
-    if (expert < static_cast<int>(gridDim.x) - 1) {
-        return;
-    }
-
-    expert_bounds[gridDim.x] = nex_prev + it_compact;
-}
-
-template <int n_expert_used_template>
-static void launch_mmq_ids_helper(
-        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
-        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
-    GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mmq_ids_helper_store");
-    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
-
-    const int id = ggml_cuda_get_device();
-    const int warp_size = ggml_cuda_info().devices[id].warp_size;
-    const size_t smpbo  = ggml_cuda_info().devices[id].smpbo;
-    CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
-
-    const dim3 num_blocks(n_experts, 1, 1);
-    const dim3 block_size(warp_size, 1, 1);
-    const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
-    GGML_ASSERT(nbytes_shared <= smpbo);
-    mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
-        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
-}
+#include "mmid.cuh"

 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
...
@@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q(
     const int si1  = ids->nb[1] / ggml_element_size(ids);
     const int sis1 = nb12 / nb11;

-    switch (n_expert_used) {
-        case 2:
-            launch_mmq_ids_helper< 2>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+    ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
                 ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-            break;
-        case 4:
-            launch_mmq_ids_helper< 4>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-            break;
-        case 6:
-            launch_mmq_ids_helper< 6>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-            break;
-        case 8:
-            launch_mmq_ids_helper< 8>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-            break;
-        case 16:
-            launch_mmq_ids_helper<16>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-            break;
-        case 32:
-            launch_mmq_ids_helper<32>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-            break;
-        default:
-            launch_mmq_ids_helper< 0>((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-            break;
-    }
     CUDA_CHECK(cudaGetLastError());
 }
...
...
ml/backend/ggml/ggml/src/ggml-cuda/mmvf.cu
...
@@ -7,14 +7,14 @@ template <typename T, typename type_acc, int ncols_dst, int block_size>
 static __global__ void mul_mat_vec_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
         const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
-        const int   channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int   sample_ratio,  const int stride_sample_x,  const int stride_sample_y,  const int stride_sample_dst) {
+        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const uint3 sample_ratio,  const int stride_sample_x,  const int stride_sample_y,  const int stride_sample_dst) {
     const int row         = blockIdx.x;
     const int channel_dst = blockIdx.y;
-    const int channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int channel_x   = ids ? ids[channel_dst]          : fastdiv((uint32_t) channel_dst, channel_ratio);
     const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
     const int sample_dst  = blockIdx.z;
-    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio);
     const int sample_y    = sample_dst;
     const int tid         = threadIdx.x;
...
@@ -47,8 +47,8 @@ static __global__ void mul_mat_vec_f(
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += tmpx.x * tmpy.x;
-                sumf[j] += tmpx.y * tmpy.y;
+                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
             }
         }
     } else if constexpr (std::is_same_v<T, half>) {
...
@@ -61,8 +61,8 @@ static __global__ void mul_mat_vec_f(
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += tmpx.x * tmpy.x;
-                sumf[j] += tmpx.y * tmpy.y;
+                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
             }
         }
     } else {
...
@@ -88,16 +88,32 @@ static __global__ void mul_mat_vec_f(
 #endif // FP16_AVAILABLE
         }
     } else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+//TODO: add support for ggml_cuda_mad for hip_bfloat162
+#if defined(GGML_USE_HIP)
         const int * x2 = (const int *) x;
         for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const int tmpx = x2[col2];
 #pragma unroll
             for (int j = 0; j < ncols_dst; ++j) {
                 const float2 tmpy = y2[j*stride_col_y2 + col2];
-                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
-                sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+                const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
+                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
+                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
             }
         }
+#else
+        const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const nv_bfloat162 tmpx = x2[col2];
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+            }
+        }
+#endif
     } else {
         static_assert(std::is_same_v<T, void>, "unsupported type");
    }
...
@@ -140,8 +156,8 @@ static void launch_mul_mat_vec_f_cuda(
     GGML_ASSERT(stride_col_y % 2 == 0);
     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
     GGML_ASSERT(       nsamples_dst % nsamples_x  == 0);
-    const int64_t channel_ratio = nchannels_dst / nchannels_x;
-    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
+    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
+    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst / nsamples_x);

     const int device    = ggml_cuda_get_device();
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
...
@@ -167,50 +183,50 @@ static void launch_mul_mat_vec_f_cuda(
        case 32: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                channel_ratio,    stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio,    stride_sample_x, stride_sample_y, stride_sample_dst);
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 64: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                channel_ratio,    stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio,    stride_sample_x, stride_sample_y, stride_sample_dst);
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 96: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                channel_ratio,    stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio,    stride_sample_x, stride_sample_y, stride_sample_dst);
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 128: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                channel_ratio,    stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio,    stride_sample_x, stride_sample_y, stride_sample_dst);
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 160: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                channel_ratio,    stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio,    stride_sample_x, stride_sample_y, stride_sample_dst);
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 192: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                channel_ratio,    stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio,    stride_sample_x, stride_sample_y, stride_sample_dst);
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 224: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                channel_ratio,    stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio,    stride_sample_x, stride_sample_y, stride_sample_dst);
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        case 256: {
            mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
-                channel_ratio,    stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio,    stride_sample_x, stride_sample_y, stride_sample_dst);
+                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
        } break;
        default: {
            GGML_ABORT("fatal error");
...
...
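The mmvf changes above swap a plain integer division (channel_dst / channel_ratio, sample_dst / sample_ratio) for a precomputed fastdiv: init_fastdiv_values packs a multiplier and shift for a runtime-constant divisor, and fastdiv / fast_div_modulo then replace the division with a multiply-high and shift inside the kernel. The standalone sketch below (not part of the commit) only demonstrates the underlying multiply-and-shift trick and checks it by brute force; it is not ggml's uint3 encoding, and it relies on the __uint128_t compiler extension.

// Illustrative only: division by a runtime-constant divisor via multiply and shift.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

struct fastdiv_vals { uint64_t mult; uint32_t shift; };

// Precompute mult = ceil(2^(32+shift) / d) with shift = ceil(log2(d)); then for any 32-bit n,
// (n * mult) >> (32 + shift) equals n / d.
static fastdiv_vals make_fastdiv(uint32_t d) {
    uint32_t shift = 0;
    while ((1ull << shift) < d) {
        ++shift;
    }
    const uint64_t mult = (uint64_t) ((((__uint128_t) 1 << (32 + shift)) + d - 1) / d);
    return {mult, shift};
}

static uint32_t fastdiv(uint32_t n, const fastdiv_vals & fd) {
    return (uint32_t) (((__uint128_t) n * fd.mult) >> (32 + fd.shift));
}

int main() {
    for (uint32_t d : {1u, 3u, 6u, 7u, 32u, 1000u}) {
        const fastdiv_vals fd = make_fastdiv(d);
        for (uint32_t n = 0; n < 1000000; ++n) {
            assert(fastdiv(n, fd) == n / d);   // exact for all 32-bit n, not just this test range
        }
    }
    std::printf("fastdiv sketch OK\n");
    return 0;
}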
ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu
...
@@ -4,16 +4,61 @@
 #include <initializer_list>

+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+template <int experts_per_thread, bool use_limit>
+__device__ void softmax_warp_inplace(float (& vals)[experts_per_thread], const int limit, const int lane) {
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            max_val = max(max_val, vals[i]);
+        }
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    float sum = 0.f;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            const float val = expf(vals[i] - max_val);
+            vals[i] = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
+    }
+
+    sum = warp_reduce_sum(sum);
+
+    const float inv_sum = 1.0f / sum;
+
+#pragma unroll
+    for (int i = 0; i < experts_per_thread; i++) {
+        const int  idx    = lane + i * WARP_SIZE;
+        const bool active = !use_limit || (idx < limit);
+        if (active) {
+            vals[i] *= inv_sum;
+        }
+    }
+}
+
 /*
     This kernel does the following:
-    1. softmax over the logits per token [n_experts, n_tokens]
+    1. optionally softmax over the logits per token [n_experts, n_tokens]
     2. argmax reduce over the top-k (n_experts_used) logits
     3. write weights + ids to global memory
-    4. optionally normalize the weights
+    4. optionally normalize the weights or apply softmax over the selected logits

     It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
 */
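For checking the fused kernel against something simple, the pipeline described in the comment above can be written per token on the CPU as: (optionally) softmax over all logits, greedy top-k selection with masking, then either renormalization of the selected weights or, in the delayed variant, a softmax over only the selected logits. The sketch below (not part of the commit) is a hedged reference with illustrative names, not the ggml implementation.

// Illustrative only: CPU reference of the softmax -> top-k -> (normalize | delayed softmax) routing.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

static void softmax_inplace(std::vector<float> & v) {
    const float mx = *std::max_element(v.begin(), v.end());
    float sum = 0.f;
    for (float & x : v) { x = std::exp(x - mx); sum += x; }
    for (float & x : v) { x /= sum; }
}

static void topk_moe_ref(std::vector<float> logits, int n_expert_used, bool with_norm, bool delayed_softmax,
                         std::vector<float> & weights, std::vector<int> & ids) {
    // with_norm and delayed_softmax are mutually exclusive, mirroring the CUDA path.
    std::vector<float> wt = logits;
    if (!delayed_softmax) {
        softmax_inplace(wt);                       // softmax over all logits first
    }
    weights.clear();
    ids.clear();
    for (int k = 0; k < n_expert_used; ++k) {      // greedy argmax, mask, repeat
        const int best = (int) (std::max_element(wt.begin(), wt.end()) - wt.begin());
        weights.push_back(wt[best]);
        ids.push_back(best);
        wt[best] = -INFINITY;
    }
    if (delayed_softmax) {
        softmax_inplace(weights);                  // softmax only over the selected logits
    }
    if (with_norm) {
        float sum = 0.f;
        for (float w : weights) sum += w;
        for (float & w : weights) w /= sum;
    }
}

int main() {
    const std::vector<float> logits = {0.1f, 2.0f, -1.0f, 0.5f, 1.5f, 0.0f, 3.0f, -0.5f};
    std::vector<float> weights;
    std::vector<int>   ids;
    topk_moe_ref(logits, /*n_expert_used=*/2, /*with_norm=*/true, /*delayed_softmax=*/false, weights, ids);
    for (size_t i = 0; i < weights.size(); ++i) {
        std::printf("expert %d weight %.4f\n", ids[i], weights[i]);
    }
    return 0;
}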
template
<
int
n_experts
,
bool
with_norm
>
template
<
int
n_experts
,
bool
with_norm
,
bool
delayed_softmax
=
false
>
__launch_bounds__
(
4
*
WARP_SIZE
,
1
)
__global__
void
topk_moe_cuda
(
const
float
*
logits
,
float
*
weights
,
int32_t
*
ids
,
...
...
@@ -30,52 +75,31 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
    constexpr int experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;

    float logits_r[experts_per_thread];
    float wt[experts_per_thread];

#pragma unroll
    for (int i = 0; i < n_experts; i += WARP_SIZE) {
        const int expert = i + threadIdx.x;
        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[expert] : -INFINITY;
        wt[i / WARP_SIZE]       = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
    }

    float max_val = logits_r[0];

#pragma unroll
    for (int i = 1; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        max_val         = max(val, max_val);
    if constexpr (!delayed_softmax) {
        softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
    }

    max_val = warp_reduce_max(max_val);

    float wt[experts_per_thread];
    float tmp = 0.f;

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const float val = logits_r[i];
        wt[i]           = expf(val - max_val);
        tmp += wt[i];
    }

    //at this point, each thread holds either a portion of the softmax distribution
    //or the raw logits. We do the argmax reduce over n_expert_used, each time marking
    //the expert weight as -inf to exclude from the next iteration
    tmp = warp_reduce_sum(tmp);

    float wt_sum = 0.f;

    const float inv_sum = 1.0f / tmp;

    float output_weights[experts_per_thread];

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        wt[i] = wt[i] * inv_sum;
        output_weights[i] = 0.f;
    }

    //at this point, each thread holds a portion of softmax,
    //we do the argmax reduce over n_expert_used, each time marking
    //the expert weight as -inf to exclude from the next iteration
    float wt_sum = 0.f;

    extern __shared__ float data_topk_shared[];
    float * wt_shared_ptr = data_topk_shared + threadIdx.y * n_expert_used;

    for (int k = 0; k < n_expert_used; k++) {
        float max_val    = wt[0];
        int   max_expert = threadIdx.x;
...
...
@@ -99,10 +123,13 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
        }
    }

        if ((k & (WARP_SIZE - 1)) == threadIdx.x) {
            output_weights[k / WARP_SIZE] = max_val;
        }

        if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
            wt[max_expert / WARP_SIZE] = -INFINITY;

            wt_shared_ptr[k] = max_val;
            ids[k]           = max_expert;
            if constexpr (with_norm) {
                wt_sum += max_val;
...
...
@@ -114,17 +141,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
        wt_sum = warp_reduce_sum(wt_sum);
        const float inv_sum = 1.0f / wt_sum;

        for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
            wt_shared_ptr[i] = wt_shared_ptr[i] * inv_sum;
        for (int i = 0; i < experts_per_thread; i++) {
            output_weights[i] *= inv_sum;
        }
    }

    for (int i = threadIdx.x; i < n_expert_used; i += WARP_SIZE) {
        weights[i] = wt_shared_ptr[i];
    if constexpr (delayed_softmax) {
        softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
    }

#pragma unroll
    for (int i = 0; i < experts_per_thread; i++) {
        const int idx = i * WARP_SIZE + threadIdx.x;
        if (idx < n_expert_used) {
            weights[idx] = output_weights[i];
        }
    }
}
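For orientation, here is a minimal host-side sketch of the routing this kernel fuses, for a single token: softmax over the logits, iterative argmax selection of the top n_expert_used experts, then either normalization of the selected weights or a softmax restricted to them (the delayed path). The function and its scalar loops are illustrative only, not ggml API, and assume n_expert_used does not exceed the number of experts.

#include <algorithm>
#include <cmath>
#include <vector>

// Reference routing for one token, mimicking the with_norm / delayed_softmax modes.
static void topk_moe_ref(const std::vector<float> & logits, int n_expert_used,
                         bool with_norm, bool delayed_softmax,
                         std::vector<int> & ids, std::vector<float> & weights) {
    std::vector<float> wt = logits;

    if (!delayed_softmax) {
        // regular path: numerically stable softmax over all logits first
        const float max_val = *std::max_element(wt.begin(), wt.end());
        float sum = 0.0f;
        for (float & v : wt) { v = std::exp(v - max_val); sum += v; }
        for (float & v : wt) { v /= sum; }
    }

    ids.clear();
    weights.clear();
    float wt_sum = 0.0f;

    for (int k = 0; k < n_expert_used; k++) {
        // argmax over the remaining entries; mask the winner with -inf afterwards
        int best = 0;
        for (int e = 1; e < (int) wt.size(); e++) {
            if (wt[e] > wt[best]) {
                best = e;
            }
        }
        ids.push_back(best);
        weights.push_back(wt[best]);
        wt_sum  += wt[best];
        wt[best] = -INFINITY;
    }

    if (with_norm) {
        for (float & w : weights) { w /= wt_sum; }
    } else if (delayed_softmax) {
        // delayed path: softmax only over the selected raw logits
        const float max_val = *std::max_element(weights.begin(), weights.end());
        float sum = 0.0f;
        for (float & w : weights) { w = std::exp(w - max_val); sum += w; }
        for (float & w : weights) { w /= sum; }
    }
}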
template <bool with_norm>
template <bool with_norm, bool delayed_softmax = false>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const float * logits,
                                 float *       weights,
...
...
@@ -132,53 +167,53 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 const int n_rows,
                                 const int n_expert,
                                 const int n_expert_used) {
    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");

    const int rows_per_block = 4;
    dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
    cudaStream_t stream = ctx.stream();

    const int nbytes_shared = n_expert_used * rows_per_block * sizeof(float);

    switch (n_expert) {
        case 1:
            topk_moe_cuda<1, with_norm><<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            topk_moe_cuda<1, with_norm, delayed_softmax><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 2:
            topk_moe_cuda<2, with_norm><<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            topk_moe_cuda<2, with_norm, delayed_softmax><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 4:
            topk_moe_cuda<4, with_norm><<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            topk_moe_cuda<4, with_norm, delayed_softmax><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 8:
            topk_moe_cuda<8, with_norm><<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            topk_moe_cuda<8, with_norm, delayed_softmax><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 16:
            topk_moe_cuda<16, with_norm><<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            topk_moe_cuda<16, with_norm, delayed_softmax><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 32:
            topk_moe_cuda<32, with_norm><<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            topk_moe_cuda<32, with_norm, delayed_softmax><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 64:
            topk_moe_cuda<64, with_norm><<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            topk_moe_cuda<64, with_norm, delayed_softmax><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 128:
            topk_moe_cuda<128, with_norm><<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            topk_moe_cuda<128, with_norm, delayed_softmax><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 256:
            topk_moe_cuda<256, with_norm><<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            topk_moe_cuda<256, with_norm, delayed_softmax><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        case 512:
            topk_moe_cuda<512, with_norm><<<grid_dims, block_dims, nbytes_shared, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            topk_moe_cuda<512, with_norm, delayed_softmax><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used);
            break;
        default:
            GGML_ASSERT(false && "fatal error");
...
...
@@ -190,7 +225,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor * logits,
                           ggml_tensor *       weights,
                           ggml_tensor *       ids,
                           const bool          with_norm) {
                           const bool          with_norm,
                           const bool          delayed_softmax) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);
...
...
@@ -198,7 +234,7 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
    const int n_experts = logits->ne[0];
    const int n_rows    = logits->ne[1];

    const float * logits_d  = (const float *) logits->src[0]->data;
    const float * logits_d  = (const float *) logits->data;
    float *       weights_d = (float *) weights->data;
    int32_t *     ids_d     = (int32_t *) ids->data;
...
...
@@ -209,7 +245,11 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
    if (with_norm) {
        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
    } else {
        launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
        if (delayed_softmax) {
            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
        } else {
            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used);
        }
    }
}
...
...
@@ -242,7 +282,7 @@ bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tenso
    return true;
}

std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
                                                            GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                            GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
...
...
@@ -250,8 +290,19 @@ std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm) {
    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                               GGML_OP_VIEW,     GGML_OP_GET_ROWS };

    static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,     GGML_OP_GET_ROWS,
                                                                       GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };

    GGML_ASSERT(!norm || !delayed_softmax);

    if (delayed_softmax) {
        return delayed_softmax_ops;
    }

    if (norm) {
        return norm_ops;
    }
    return no_norm_ops;
}
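To illustrate how such an op list is typically consumed: a backend can ask whether the upcoming graph nodes match the pattern and, if so, dispatch the fused kernel instead. A minimal sketch follows; the helper name, the includes, and the output offsets are assumptions for illustration, and the real dispatch logic in ggml-cuda.cu is not shown in this diff.

#include "ggml-impl.h"
#include "topk-moe.cuh"

// Hypothetical call site (illustration only): check whether the nodes starting at
// position i form the softmax->top-k->get_rows pattern that the fused kernel replaces.
// The output offsets below are placeholders; the real ones depend on the matched variant.
static bool can_fuse_topk_moe_example(const struct ggml_cgraph * cgraph, int i) {
    const std::initializer_list<enum ggml_op> ops =
        ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
    return ggml_can_fuse_subgraph(cgraph, i, ops, /*outputs=*/{ i + 1, i + 2 });
}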
ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cuh
View file @
544b6739
...
...
@@ -6,9 +6,10 @@
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           const ggml_tensor * logits,
                           ggml_tensor *       weights,
                           ggml_tensor *       top_k,
                           const bool          with_norm);
                           ggml_tensor *       ids,
                           const bool          with_norm,
                           const bool          delayed_softmax = false);

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights);

std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt
View file @
544b6739
...
...
@@ -28,8 +28,10 @@ if (CXX_IS_HIPCC)
            " Prefer setting the HIP compiler directly. See README for details.")
    endif()
else()
    # Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
    if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
    # Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
    if (GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
        set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})
    elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
        set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
    endif()
    cmake_minimum_required(VERSION 3.21)
...
...
ml/backend/ggml/ggml/src/ggml-impl.h
View file @
544b6739
...
...
@@ -565,14 +565,23 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)

static inline int32_t ggml_node_get_use_count(const struct ggml_cgraph * cgraph, int node_idx) {
    const struct ggml_tensor * node = cgraph->nodes[node_idx];

    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
        return 0;
    }
    return cgraph->use_counts[hash_pos];
}

// return true if the node's results are only used by N other nodes
// and can be fused into their calculations.
static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
    const struct ggml_tensor * node = cgraph->nodes[node_idx];

    // check the use count against how many we're replacing
    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
    if (ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
        return false;
    }
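As a usage illustration (a hypothetical call site, not part of this patch): a fusion pass would typically only fold a node whose result feeds exactly one consumer.

// Illustration only: true if node node_idx's output feeds exactly one consumer,
// so the intermediate tensor could be elided by a fusion pass.
static inline bool node_is_single_use(const struct ggml_cgraph * cgraph, int node_idx) {
    return ggml_node_has_n_uses(cgraph, node_idx, 1);
}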
...
...
@@ -638,6 +647,36 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
    return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops);
}

GGML_API bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
                                         const int *                node_idxs,
                                         int                        count,
                                         const enum ggml_op *       ops,
                                         const int *                outputs,
                                         int                        num_outputs);

// Returns true if the subgraph formed by {node_idxs} can be fused;
// checks whether all nodes which are not part of outputs can be elided
// by checking if their num_uses are confined to the subgraph
static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
                                          int                        node_idx,
                                          int                        count,
                                          const enum ggml_op *       ops,
                                          const int *                outputs,
                                          int                        num_outputs) {
    GGML_ASSERT(count < 32);
    if (node_idx + count > cgraph->n_nodes) {
        return false;
    }

    int idxs[32];
    for (int i = 0; i < count; ++i) {
        idxs[i] = node_idx + i;
    }

    return ggml_can_fuse_subgraph_ext(cgraph, idxs, count, ops, outputs, num_outputs);
}

// Management libraries for fetching more accurate free VRAM data
GGML_API int ggml_nvml_init();
GGML_API int ggml_nvml_get_device_memory(const char * uuid, size_t * free, size_t * total);
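A minimal sketch of how a caller might use these declarations to query free VRAM. Treating a zero return value as success and the shape of the uuid string are assumptions for illustration, not documented behavior.

#include <stdio.h>

// Illustration only: query free/total memory for a device identified by its UUID.
static void print_free_vram_example(const char * uuid) {
    size_t free_mem  = 0;
    size_t total_mem = 0;
    if (ggml_nvml_init() == 0 &&
        ggml_nvml_get_device_memory(uuid, &free_mem, &total_mem) == 0) {
        printf("free: %zu / total: %zu bytes\n", free_mem, total_mem);
    }
}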
...
...
@@ -662,6 +701,13 @@ inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::
    return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int) ops.size());
}

inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph *          cgraph,
                                   int                                 start_idx,
                                   std::initializer_list<enum ggml_op> ops,
                                   std::initializer_list<int>          outputs = {}) {
    return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
}

// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
...
...
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp
View file @
544b6739
...
...
@@ -1406,6 +1406,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
    return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);

    GGML_ASSERT(ggml_is_contiguous(op->src[0]));
    GGML_ASSERT(ggml_is_contiguous(op->src[1]));
    GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
    GGML_ASSERT(op->type == GGML_TYPE_F32);

    char base[256];
    char name[256];

    snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
    snprintf(name, 256, "%s", base);

    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

    return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_UPSCALE);
...
...
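For example, with f16 weights and f32 input the pipeline name built above resolves to kernel_conv_transpose_2d_f16_f32. A standalone sketch of the same formatting, assuming ggml_type_name() yields "f16" and "f32" for those types:

#include <cstdio>

int main() {
    char base[256];
    // Mirrors the snprintf in the function above, with the type names filled in directly.
    snprintf(base, sizeof(base), "kernel_conv_transpose_2d_%s_%s", "f16", "f32");
    printf("%s\n", base);  // prints: kernel_conv_transpose_2d_f16_f32
    return 0;
}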