Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcfc474d
Commit
fcfc474d
authored
Apr 09, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.3' into v0.8.3-dev
parents
bb94d2e5
296c6572
Changes
503
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
543 additions
and
238 deletions
+543
-238
csrc/quantization/gguf/dequantize.cuh
csrc/quantization/gguf/dequantize.cuh
+34
-31
csrc/quantization/gguf/ggml-common.h
csrc/quantization/gguf/ggml-common.h
+21
-1
csrc/quantization/gguf/gguf_kernel.cu
csrc/quantization/gguf/gguf_kernel.cu
+20
-15
csrc/quantization/gguf/moe.cuh
csrc/quantization/gguf/moe.cuh
+80
-80
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
+6
-6
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+14
-14
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
+8
-8
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
+5
-5
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
+7
-7
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+7
-7
csrc/quantization/utils.cuh
csrc/quantization/utils.cuh
+59
-0
csrc/rocm/attention.cu
csrc/rocm/attention.cu
+99
-58
csrc/rocm/ops.h
csrc/rocm/ops.h
+3
-2
csrc/rocm/torch_bindings.cpp
csrc/rocm/torch_bindings.cpp
+3
-1
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+39
-3
docker/Dockerfile
docker/Dockerfile
+0
-0
docker/Dockerfile.arm
docker/Dockerfile.arm
+0
-0
docker/Dockerfile.cpu
docker/Dockerfile.cpu
+138
-0
docker/Dockerfile.hpu
docker/Dockerfile.hpu
+0
-0
docker/Dockerfile.neuron
docker/Dockerfile.neuron
+0
-0
No files found.
csrc/quantization/gguf/dequantize.cuh
View file @
fcfc474d
...
...
@@ -94,8 +94,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
dfloat2
v
;
dequantize_kernel
(
vx
,
ib
,
iqs
,
v
);
y
[
iybs
+
iqs
+
0
]
=
v
.
x
;
y
[
iybs
+
iqs
+
y_offset
]
=
v
.
y
;
y
[
iybs
+
iqs
+
0
]
=
convert_from_half
<
dst_t
>
(
v
.
x
)
;
y
[
iybs
+
iqs
+
y_offset
]
=
convert_from_half
<
dst_t
>
(
v
.
y
)
;
}
template
<
typename
dst_t
>
...
...
@@ -114,10 +114,10 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
half
dall
=
__low2half
(
x
[
i
].
dm
);
half
dmin
=
__high2half
(
x
[
i
].
dm
);
y
[
l
+
0
]
=
__hsub
(
__hmul
(
dall
,
__int2half_rn
((
x
[
i
].
scales
[
is
+
0
]
&
0xF
)
*
((
q
>>
0
)
&
3
))),
__hmul
(
dmin
,
__int2half_rn
(
x
[
i
].
scales
[
is
+
0
]
>>
4
)));
y
[
l
+
32
]
=
__hsub
(
__hmul
(
dall
,
__int2half_rn
((
x
[
i
].
scales
[
is
+
2
]
&
0xF
)
*
((
q
>>
2
)
&
3
))),
__hmul
(
dmin
,
__int2half_rn
(
x
[
i
].
scales
[
is
+
2
]
>>
4
)));
y
[
l
+
64
]
=
__hsub
(
__hmul
(
dall
,
__int2half_rn
((
x
[
i
].
scales
[
is
+
4
]
&
0xF
)
*
((
q
>>
4
)
&
3
))),
__hmul
(
dmin
,
__int2half_rn
(
x
[
i
].
scales
[
is
+
4
]
>>
4
)));
y
[
l
+
96
]
=
__hsub
(
__hmul
(
dall
,
__int2half_rn
((
x
[
i
].
scales
[
is
+
6
]
&
0xF
)
*
((
q
>>
6
)
&
3
))),
__hmul
(
dmin
,
__int2half_rn
(
x
[
i
].
scales
[
is
+
6
]
>>
4
)));
y
[
l
+
0
]
=
convert_from_half
<
dst_t
>
(
__hsub
(
__hmul
(
dall
,
__int2half_rn
((
x
[
i
].
scales
[
is
+
0
]
&
0xF
)
*
((
q
>>
0
)
&
3
))),
__hmul
(
dmin
,
__int2half_rn
(
x
[
i
].
scales
[
is
+
0
]
>>
4
)))
)
;
y
[
l
+
32
]
=
convert_from_half
<
dst_t
>
(
__hsub
(
__hmul
(
dall
,
__int2half_rn
((
x
[
i
].
scales
[
is
+
2
]
&
0xF
)
*
((
q
>>
2
)
&
3
))),
__hmul
(
dmin
,
__int2half_rn
(
x
[
i
].
scales
[
is
+
2
]
>>
4
)))
)
;
y
[
l
+
64
]
=
convert_from_half
<
dst_t
>
(
__hsub
(
__hmul
(
dall
,
__int2half_rn
((
x
[
i
].
scales
[
is
+
4
]
&
0xF
)
*
((
q
>>
4
)
&
3
))),
__hmul
(
dmin
,
__int2half_rn
(
x
[
i
].
scales
[
is
+
4
]
>>
4
)))
)
;
y
[
l
+
96
]
=
convert_from_half
<
dst_t
>
(
__hsub
(
__hmul
(
dall
,
__int2half_rn
((
x
[
i
].
scales
[
is
+
6
]
&
0xF
)
*
((
q
>>
6
)
&
3
))),
__hmul
(
dmin
,
__int2half_rn
(
x
[
i
].
scales
[
is
+
6
]
>>
4
)))
)
;
}
template
<
typename
dst_t
>
...
...
@@ -148,7 +148,9 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
const
uint8_t
*
q
=
x
[
i
].
qs
+
32
*
n
;
const
uint8_t
*
hm
=
x
[
i
].
hmask
;
for
(
int
l
=
l0
;
l
<
l0
+
4
;
++
l
)
y
[
l
]
=
__hmul
(
dl
,
__int2half_rn
((
int8_t
)((
q
[
l
]
>>
shift
)
&
3
)
-
((
hm
[
l
]
&
m
)
?
0
:
4
)));
for
(
int
l
=
l0
;
l
<
l0
+
4
;
++
l
)
{
y
[
l
]
=
convert_from_half
<
dst_t
>
(
__hmul
(
dl
,
__int2half_rn
((
int8_t
)((
q
[
l
]
>>
shift
)
&
3
)
-
((
hm
[
l
]
&
m
)
?
0
:
4
))));
}
}
static
inline
__device__
void
get_scale_min_k4
(
int
j
,
const
uint8_t
*
q
,
uint8_t
&
d
,
uint8_t
&
m
)
{
...
...
@@ -188,8 +190,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
const
half
d2
=
__hmul
(
dall
,
__int2half_rn
(
sc
));
const
half
m2
=
__hmul
(
dmin
,
__int2half_rn
(
m
));
for
(
int
l
=
0
;
l
<
n
;
++
l
)
{
y
[
l
+
0
]
=
__hsub
(
__hmul
(
d1
,
__int2half_rn
(
q
[
l
]
&
0xF
)),
m1
);
y
[
l
+
32
]
=
__hsub
(
__hmul
(
d2
,
__int2half_rn
(
q
[
l
]
>>
4
)),
m2
);
y
[
l
+
0
]
=
convert_from_half
<
dst_t
>
(
__hsub
(
__hmul
(
d1
,
__int2half_rn
(
q
[
l
]
&
0xF
)),
m1
)
)
;
y
[
l
+
32
]
=
convert_from_half
<
dst_t
>
(
__hsub
(
__hmul
(
d2
,
__int2half_rn
(
q
[
l
]
>>
4
)),
m2
)
)
;
}
}
...
...
@@ -220,11 +222,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
const
half
d2
=
__hmul
(
dall
,
__int2half_rn
(
sc
));
const
half
m2
=
__hmul
(
dmin
,
__int2half_rn
(
m
));
uint8_t
hm
=
1
<<
(
2
*
il
);
y
[
0
]
=
__hsub
(
__hmul
(
d1
,
__int2half_rn
((
ql
[
0
]
&
0xF
)
+
(
qh
[
0
]
&
hm
?
16
:
0
))),
m1
);
y
[
1
]
=
__hsub
(
__hmul
(
d1
,
__int2half_rn
((
ql
[
1
]
&
0xF
)
+
(
qh
[
1
]
&
hm
?
16
:
0
))),
m1
);
y
[
0
]
=
convert_from_half
<
dst_t
>
(
__hsub
(
__hmul
(
d1
,
__int2half_rn
((
ql
[
0
]
&
0xF
)
+
(
qh
[
0
]
&
hm
?
16
:
0
))),
m1
)
)
;
y
[
1
]
=
convert_from_half
<
dst_t
>
(
__hsub
(
__hmul
(
d1
,
__int2half_rn
((
ql
[
1
]
&
0xF
)
+
(
qh
[
1
]
&
hm
?
16
:
0
))),
m1
)
)
;
hm
<<=
1
;
y
[
32
]
=
__hsub
(
__hmul
(
d2
,
__int2half_rn
((
ql
[
0
]
>>
4
)
+
(
qh
[
0
]
&
hm
?
16
:
0
))),
m2
);
y
[
33
]
=
__hsub
(
__hmul
(
d2
,
__int2half_rn
((
ql
[
1
]
>>
4
)
+
(
qh
[
1
]
&
hm
?
16
:
0
))),
m2
);
y
[
32
]
=
convert_from_half
<
dst_t
>
(
__hsub
(
__hmul
(
d2
,
__int2half_rn
((
ql
[
0
]
>>
4
)
+
(
qh
[
0
]
&
hm
?
16
:
0
))),
m2
)
)
;
y
[
33
]
=
convert_from_half
<
dst_t
>
(
__hsub
(
__hmul
(
d2
,
__int2half_rn
((
ql
[
1
]
>>
4
)
+
(
qh
[
1
]
&
hm
?
16
:
0
))),
m2
)
)
;
}
template
<
typename
dst_t
>
...
...
@@ -247,10 +249,10 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
const
uint8_t
qh
=
x
[
i
].
qh
[
32
*
ip
+
il
];
const
int8_t
*
sc
=
x
[
i
].
scales
+
is
;
y
[
0
]
=
__hmul
(
d
,
__int2half_rn
(
sc
[
0
]
*
((
int8_t
)((
ql
[
0
]
&
0xF
)
|
(((
qh
>>
0
)
&
3
)
<<
4
))
-
32
)));
y
[
32
]
=
__hmul
(
d
,
__int2half_rn
(
sc
[
2
]
*
((
int8_t
)((
ql
[
32
]
&
0xF
)
|
(((
qh
>>
2
)
&
3
)
<<
4
))
-
32
)));
y
[
64
]
=
__hmul
(
d
,
__int2half_rn
(
sc
[
4
]
*
((
int8_t
)((
ql
[
0
]
>>
4
)
|
(((
qh
>>
4
)
&
3
)
<<
4
))
-
32
)));
y
[
96
]
=
__hmul
(
d
,
__int2half_rn
(
sc
[
6
]
*
((
int8_t
)((
ql
[
32
]
>>
4
)
|
(((
qh
>>
6
)
&
3
)
<<
4
))
-
32
)));
y
[
0
]
=
convert_from_half
<
dst_t
>
(
__hmul
(
d
,
__int2half_rn
(
sc
[
0
]
*
((
int8_t
)((
ql
[
0
]
&
0xF
)
|
(((
qh
>>
0
)
&
3
)
<<
4
))
-
32
)))
)
;
y
[
32
]
=
convert_from_half
<
dst_t
>
(
__hmul
(
d
,
__int2half_rn
(
sc
[
2
]
*
((
int8_t
)((
ql
[
32
]
&
0xF
)
|
(((
qh
>>
2
)
&
3
)
<<
4
))
-
32
)))
)
;
y
[
64
]
=
convert_from_half
<
dst_t
>
(
__hmul
(
d
,
__int2half_rn
(
sc
[
4
]
*
((
int8_t
)((
ql
[
0
]
>>
4
)
|
(((
qh
>>
4
)
&
3
)
<<
4
))
-
32
)))
)
;
y
[
96
]
=
convert_from_half
<
dst_t
>
(
__hmul
(
d
,
__int2half_rn
(
sc
[
6
]
*
((
int8_t
)((
ql
[
32
]
>>
4
)
|
(((
qh
>>
6
)
&
3
)
<<
4
))
-
32
)))
)
;
}
template
<
typename
dst_t
>
...
...
@@ -269,7 +271,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
const
uint32_t
aux32
=
q2
[
2
]
|
(
q2
[
3
]
<<
16
);
const
float
d
=
__half2float
(
x
[
i
].
d
)
*
(
0.5
f
+
(
aux32
>>
28
))
*
0.25
f
;
const
uint8_t
signs
=
ksigns_iq2xs
[(
aux32
>>
7
*
il
)
&
127
];
for
(
int
j
=
0
;
j
<
8
;
++
j
)
y
[
j
]
=
__float2half
(
d
*
grid
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
]
?
-
1.
f
:
1.
f
)
)
;
for
(
int
j
=
0
;
j
<
8
;
++
j
)
y
[
j
]
=
d
*
grid
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
]
?
-
1.
f
:
1.
f
);
}
template
<
typename
dst_t
>
...
...
@@ -286,7 +288,7 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
const
uint8_t
*
grid
=
(
const
uint8_t
*
)(
iq2xs_grid
+
(
q2
[
il
]
&
511
));
const
float
d
=
__half2float
(
x
[
i
].
d
)
*
(
0.5
f
+
((
x
[
i
].
scales
[
ib
]
>>
4
*
(
il
/
2
))
&
0xf
))
*
0.25
f
;
const
uint8_t
signs
=
ksigns_iq2xs
[
q2
[
il
]
>>
9
];
for
(
int
j
=
0
;
j
<
8
;
++
j
)
y
[
j
]
=
__float2half
(
d
*
grid
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
]
?
-
1.
f
:
1.
f
)
)
;
for
(
int
j
=
0
;
j
<
8
;
++
j
)
y
[
j
]
=
d
*
grid
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
]
?
-
1.
f
:
1.
f
);
}
...
...
@@ -303,7 +305,7 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
const
uint8_t
*
grid
=
(
const
uint8_t
*
)(
iq2s_grid
+
(
x
[
i
].
qs
[
4
*
ib
+
il
]
|
((
x
[
i
].
qh
[
ib
]
<<
(
8
-
2
*
il
))
&
0x300
)));
const
float
d
=
__half2float
(
x
[
i
].
d
)
*
(
0.5
f
+
((
x
[
i
].
scales
[
ib
]
>>
4
*
(
il
/
2
))
&
0xf
))
*
0.25
f
;
const
uint8_t
signs
=
x
[
i
].
qs
[
QK_K
/
8
+
4
*
ib
+
il
];
for
(
int
j
=
0
;
j
<
8
;
++
j
)
y
[
j
]
=
__float2half
(
d
*
grid
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
]
?
-
1.
f
:
1.
f
)
)
;
for
(
int
j
=
0
;
j
<
8
;
++
j
)
y
[
j
]
=
d
*
grid
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
]
?
-
1.
f
:
1.
f
);
}
template
<
typename
dst_t
>
...
...
@@ -324,8 +326,8 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
const
float
d
=
__half2float
(
x
[
i
].
d
)
*
(
0.5
f
+
(
aux32
>>
28
))
*
0.5
f
;
const
uint8_t
signs
=
ksigns_iq2xs
[(
aux32
>>
7
*
il
)
&
127
];
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
y
[
j
+
0
]
=
__float2half
(
d
*
grid1
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
+
0
]
?
-
1.
f
:
1.
f
)
)
;
y
[
j
+
4
]
=
__float2half
(
d
*
grid2
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
+
4
]
?
-
1.
f
:
1.
f
)
)
;
y
[
j
+
0
]
=
d
*
grid1
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
+
0
]
?
-
1.
f
:
1.
f
);
y
[
j
+
4
]
=
d
*
grid2
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
+
4
]
?
-
1.
f
:
1.
f
);
}
}
...
...
@@ -345,8 +347,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
const
float
d
=
__half2float
(
x
[
i
].
d
)
*
(
0.5
f
+
((
x
[
i
].
scales
[
ib
/
2
]
>>
4
*
(
ib
%
2
))
&
0xf
))
*
0.5
f
;
const
uint8_t
signs
=
x
[
i
].
signs
[
4
*
ib
+
il
];
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
y
[
j
+
0
]
=
__float2half
(
d
*
grid1
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
+
0
]
?
-
1.
f
:
1.
f
)
)
;
y
[
j
+
4
]
=
__float2half
(
d
*
grid2
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
+
4
]
?
-
1.
f
:
1.
f
)
)
;
y
[
j
+
0
]
=
d
*
grid1
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
+
0
]
?
-
1.
f
:
1.
f
);
y
[
j
+
4
]
=
d
*
grid2
[
j
]
*
(
signs
&
kmask_iq2xs
[
j
+
4
]
?
-
1.
f
:
1.
f
);
}
}
...
...
@@ -367,7 +369,7 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
grid32
[
1
]
=
(
grid32
[
0
]
>>
4
)
&
0x0f0f0f0f
;
grid32
[
0
]
&=
0x0f0f0f0f
;
for
(
int
j
=
0
;
j
<
8
;
++
j
)
{
y
[
j
]
=
__float2half
(
d
*
(
q
[
j
]
+
delta
)
)
;
y
[
j
]
=
d
*
(
q
[
j
]
+
delta
);
}
}
...
...
@@ -392,7 +394,7 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
grid32
[
1
]
=
(
grid32
[
0
]
>>
4
)
&
0x0f0f0f0f
;
grid32
[
0
]
&=
0x0f0f0f0f
;
for
(
int
j
=
0
;
j
<
8
;
++
j
)
{
y
[
j
]
=
__float2half
(
d
*
(
q
[
j
]
+
delta
)
)
;
y
[
j
]
=
d
*
(
q
[
j
]
+
delta
);
}
}
...
...
@@ -409,8 +411,8 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
const
uint8_t
*
q4
=
x
[
ib
].
qs
+
4
*
il
;
const
float
d
=
__half2float
(
x
[
ib
].
d
);
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
y
[
j
+
0
]
=
__float2half
(
d
*
kvalues_iq4nl
[
q4
[
j
]
&
0xf
]
)
;
y
[
j
+
16
]
=
__float2half
(
d
*
kvalues_iq4nl
[
q4
[
j
]
>>
4
]
)
;
y
[
j
+
0
]
=
d
*
kvalues_iq4nl
[
q4
[
j
]
&
0xf
];
y
[
j
+
16
]
=
d
*
kvalues_iq4nl
[
q4
[
j
]
>>
4
];
}
}
...
...
@@ -427,8 +429,8 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
const
uint8_t
*
q4
=
x
[
i
].
qs
+
16
*
ib
+
4
*
il
;
const
float
d
=
__half2float
(
x
[
i
].
d
)
*
((((
x
[
i
].
scales_l
[
ib
/
2
]
>>
4
*
(
ib
%
2
))
&
0xf
)
|
(((
x
[
i
].
scales_h
>>
2
*
ib
)
&
3
)
<<
4
))
-
32
);
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
y
[
j
+
0
]
=
__float2half
(
d
*
kvalues_iq4nl
[
q4
[
j
]
&
0xf
]
)
;
y
[
j
+
16
]
=
__float2half
(
d
*
kvalues_iq4nl
[
q4
[
j
]
>>
4
]
)
;
y
[
j
+
0
]
=
d
*
kvalues_iq4nl
[
q4
[
j
]
&
0xf
];
y
[
j
+
16
]
=
d
*
kvalues_iq4nl
[
q4
[
j
]
>>
4
];
}
}
...
...
@@ -522,7 +524,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
dequantize_block_iq4_xs
<<<
nb
,
32
,
0
,
stream
>>>
(
vx
,
y
);
}
static
to_fp16_cuda_t
ggml_get_to_fp16_cuda
(
int64_t
type
)
{
template
<
typename
dst_t
>
static
to_cuda_ggml_t
<
dst_t
>
ggml_get_to_cuda
(
int64_t
type
)
{
switch
(
type
)
{
case
2
:
return
dequantize_block_cuda
<
QK4_0
,
QR4_0
,
dequantize_q4_0
>
;
...
...
csrc/quantization/gguf/ggml-common.h
View file @
fcfc474d
...
...
@@ -1063,7 +1063,8 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
typedef
half
dfloat
;
// dequantize float
typedef
half2
dfloat2
;
typedef
void
(
*
dequantize_kernel_t
)(
const
void
*
vx
,
const
int
ib
,
const
int
iqs
,
dfloat2
&
v
);
typedef
void
(
*
to_fp16_cuda_t
)(
const
void
*
__restrict__
x
,
dfloat
*
__restrict__
y
,
int
k
,
cudaStream_t
stream
);
template
<
typename
dst_t
>
using
to_cuda_ggml_t
=
void
(
*
)(
const
void
*
__restrict__
x
,
dst_t
*
__restrict__
y
,
int
k
,
cudaStream_t
stream
);
typedef
float
(
*
vec_dot_q_cuda_t
)(
const
void
*
__restrict__
vbq
,
const
block_q8_1
*
__restrict__
bq8_1
,
const
int
&
iqs
);
typedef
void
(
*
allocate_tiles_cuda_t
)(
int
**
x_ql
,
half2
**
x_dm
,
int
**
x_qh
,
int
**
x_sc
);
typedef
void
(
*
load_tiles_cuda_t
)(
...
...
@@ -1075,6 +1076,25 @@ typedef float (*vec_dot_q_mul_mat_cuda_t)(
// Utility function
template
<
typename
dst_t
>
static
__device__
__forceinline__
dst_t
convert_from_half
(
half
val
)
{
return
val
;
}
template
<
>
__device__
__forceinline__
c10
::
BFloat16
convert_from_half
<
c10
::
BFloat16
>
(
half
val
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return
__float2bfloat16
(
__half2float
(
val
));
#else
return
__half2float
(
val
);
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}
template
<
>
__device__
__forceinline__
float
convert_from_half
<
float
>
(
half
val
)
{
return
__half2float
(
val
);
}
#if defined(USE_ROCM)
#ifndef __has_builtin
...
...
csrc/quantization/gguf/gguf_kernel.cu
View file @
fcfc474d
...
...
@@ -71,14 +71,19 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
}
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
// quant weight
int64_t
type
,
int64_t
m
,
int64_t
n
)
{
int64_t
type
,
int64_t
m
,
int64_t
n
,
std
::
optional
<
at
::
ScalarType
>
const
&
dtype
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
W
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kFloat16
).
device
(
W
.
device
());
auto
dtype_
=
dtype
.
value_or
(
torch
::
kFloat16
);
auto
options
=
torch
::
TensorOptions
().
dtype
(
dtype_
).
device
(
W
.
device
());
at
::
Tensor
DW
=
torch
::
empty
({
m
,
n
},
options
);
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
to_fp16_cuda_t
to_fp16_cuda
=
ggml_get_to_fp16_cuda
(
type
);
to_fp16_cuda
((
void
*
)
W
.
data_ptr
(),
(
half
*
)
DW
.
data_ptr
(),
m
*
n
,
stream
);
VLLM_DISPATCH_FLOATING_TYPES
(
DW
.
scalar_type
(),
"ggml_dequantize"
,
[
&
]
{
auto
to_cuda
=
ggml_get_to_cuda
<
scalar_t
>
(
type
);
to_cuda
((
void
*
)
W
.
data_ptr
(),
(
scalar_t
*
)
DW
.
data_ptr
(),
m
*
n
,
stream
);
});
return
DW
;
}
...
...
@@ -375,25 +380,25 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input
int64_t
ggml_moe_get_block_size
(
int64_t
type
)
{
switch
(
type
)
{
case
2
:
return
M
MQ
_X_Q4_0
;
return
M
OE
_X_Q4_0
;
case
3
:
return
M
MQ
_X_Q4_1
;
return
M
OE
_X_Q4_1
;
case
6
:
return
M
MQ
_X_Q5_0
;
return
M
OE
_X_Q5_0
;
case
7
:
return
M
MQ
_X_Q5_1
;
return
M
OE
_X_Q5_1
;
case
8
:
return
M
MQ
_X_Q8_0
;
return
M
OE
_X_Q8_0
;
case
10
:
return
M
MQ
_X_Q2_K
;
return
M
OE
_X_Q2_K
;
case
11
:
return
M
MQ
_X_Q3_K
;
return
M
OE
_X_Q3_K
;
case
12
:
return
M
MQ
_X_Q4_K
;
return
M
OE
_X_Q4_K
;
case
13
:
return
M
MQ
_X_Q5_K
;
return
M
OE
_X_Q5_K
;
case
14
:
return
M
MQ
_X_Q6_K
;
return
M
OE
_X_Q6_K
;
}
return
0
;
}
csrc/quantization/gguf/moe.cuh
View file @
fcfc474d
...
...
@@ -129,12 +129,12 @@ static __device__ __forceinline__ void moe_q(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q4_0 64
#define M
MQ
_Y_Q4_0 128
#define M
OE
_X_Q4_0 64
#define M
OE
_Y_Q4_0 128
#define NWARPS_Q4_0 8
#else
#define M
MQ
_X_Q4_0 4
#define M
MQ
_Y_Q4_0 32
#define M
OE
_X_Q4_0 4
#define M
OE
_Y_Q4_0 32
#define NWARPS_Q4_0 4
#endif
...
...
@@ -149,8 +149,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_0, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q4_0
;
const
int
mmq_y
=
M
MQ
_Y_Q4_0
;
const
int
mmq_x
=
M
OE
_X_Q4_0
;
const
int
mmq_y
=
M
OE
_Y_Q4_0
;
const
int
nwarps
=
NWARPS_Q4_0
;
moe_q
<
scalar_t
,
QK4_0
,
QR4_0
,
QI4_0
,
true
,
block_q4_0
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -167,8 +167,8 @@ static void ggml_moe_q4_0_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
int
mmq_x
=
M
MQ
_X_Q4_0
;
int
mmq_y
=
M
MQ
_Y_Q4_0
;
int
mmq_x
=
M
OE
_X_Q4_0
;
int
mmq_y
=
M
OE
_Y_Q4_0
;
int
nwarps
=
NWARPS_Q4_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -190,12 +190,12 @@ static void ggml_moe_q4_0_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q4_1 64
#define M
MQ
_Y_Q4_1 128
#define M
OE
_X_Q4_1 64
#define M
OE
_Y_Q4_1 128
#define NWARPS_Q4_1 8
#else
#define M
MQ
_X_Q4_1 4
#define M
MQ
_Y_Q4_1 32
#define M
OE
_X_Q4_1 4
#define M
OE
_Y_Q4_1 32
#define NWARPS_Q4_1 4
#endif
...
...
@@ -210,8 +210,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_1, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q4_1
;
const
int
mmq_y
=
M
MQ
_Y_Q4_1
;
const
int
mmq_x
=
M
OE
_X_Q4_1
;
const
int
mmq_y
=
M
OE
_Y_Q4_1
;
const
int
nwarps
=
NWARPS_Q4_1
;
moe_q
<
scalar_t
,
QK4_1
,
QR4_1
,
QI4_1
,
true
,
block_q4_1
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -228,8 +228,8 @@ static void ggml_moe_q4_1_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
int
mmq_x
=
M
MQ
_X_Q4_1
;
int
mmq_y
=
M
MQ
_Y_Q4_1
;
int
mmq_x
=
M
OE
_X_Q4_1
;
int
mmq_y
=
M
OE
_Y_Q4_1
;
int
nwarps
=
NWARPS_Q4_1
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -251,12 +251,12 @@ static void ggml_moe_q4_1_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q5_0 64
#define M
MQ
_Y_Q5_0 128
#define M
OE
_X_Q5_0 64
#define M
OE
_Y_Q5_0 128
#define NWARPS_Q5_0 8
#else
#define M
MQ
_X_Q5_0 4
#define M
MQ
_Y_Q5_0 32
#define M
OE
_X_Q5_0 4
#define M
OE
_Y_Q5_0 32
#define NWARPS_Q5_0 4
#endif
...
...
@@ -271,8 +271,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_0, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_0
;
const
int
mmq_y
=
M
MQ
_Y_Q5_0
;
const
int
mmq_x
=
M
OE
_X_Q5_0
;
const
int
mmq_y
=
M
OE
_Y_Q5_0
;
const
int
nwarps
=
NWARPS_Q5_0
;
moe_q
<
scalar_t
,
QK5_0
,
QR5_0
,
QI5_0
,
false
,
block_q5_0
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -289,8 +289,8 @@ static void ggml_moe_q5_0_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_0
;
const
int
mmq_y
=
M
MQ
_Y_Q5_0
;
const
int
mmq_x
=
M
OE
_X_Q5_0
;
const
int
mmq_y
=
M
OE
_Y_Q5_0
;
const
int
nwarps
=
NWARPS_Q5_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -312,12 +312,12 @@ static void ggml_moe_q5_0_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q5_1 64
#define M
MQ
_Y_Q5_1 128
#define M
OE
_X_Q5_1 64
#define M
OE
_Y_Q5_1 128
#define NWARPS_Q5_1 8
#else
#define M
MQ
_X_Q5_1 4
#define M
MQ
_Y_Q5_1 32
#define M
OE
_X_Q5_1 4
#define M
OE
_Y_Q5_1 32
#define NWARPS_Q5_1 4
#endif
...
...
@@ -332,8 +332,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_1, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_1
;
const
int
mmq_y
=
M
MQ
_Y_Q5_1
;
const
int
mmq_x
=
M
OE
_X_Q5_1
;
const
int
mmq_y
=
M
OE
_Y_Q5_1
;
const
int
nwarps
=
NWARPS_Q5_1
;
moe_q
<
scalar_t
,
QK5_1
,
QR5_1
,
QI5_1
,
true
,
block_q5_1
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -350,8 +350,8 @@ static void ggml_moe_q5_1_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_1
;
const
int
mmq_y
=
M
MQ
_Y_Q5_1
;
const
int
mmq_x
=
M
OE
_X_Q5_1
;
const
int
mmq_y
=
M
OE
_Y_Q5_1
;
const
int
nwarps
=
NWARPS_Q5_1
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -373,12 +373,12 @@ static void ggml_moe_q5_1_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q8_0 64
#define M
MQ
_Y_Q8_0 128
#define M
OE
_X_Q8_0 64
#define M
OE
_Y_Q8_0 128
#define NWARPS_Q8_0 8
#else
#define M
MQ
_X_Q8_0 4
#define M
MQ
_Y_Q8_0 32
#define M
OE
_X_Q8_0 4
#define M
OE
_Y_Q8_0 32
#define NWARPS_Q8_0 4
#endif
...
...
@@ -393,8 +393,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q8_0, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q8_0
;
const
int
mmq_y
=
M
MQ
_Y_Q8_0
;
const
int
mmq_x
=
M
OE
_X_Q8_0
;
const
int
mmq_y
=
M
OE
_Y_Q8_0
;
const
int
nwarps
=
NWARPS_Q8_0
;
moe_q
<
scalar_t
,
QK8_0
,
QR8_0
,
QI8_0
,
false
,
block_q8_0
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -411,8 +411,8 @@ static void ggml_moe_q8_0_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q8_0
;
const
int
mmq_y
=
M
MQ
_Y_Q8_0
;
const
int
mmq_x
=
M
OE
_X_Q8_0
;
const
int
mmq_y
=
M
OE
_Y_Q8_0
;
const
int
nwarps
=
NWARPS_Q8_0
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -434,12 +434,12 @@ static void ggml_moe_q8_0_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q2_K 64
#define M
MQ
_Y_Q2_K 128
#define M
OE
_X_Q2_K 64
#define M
OE
_Y_Q2_K 128
#define NWARPS_Q2_K 8
#else
#define M
MQ
_X_Q2_K 4
#define M
MQ
_Y_Q2_K 32
#define M
OE
_X_Q2_K 4
#define M
OE
_Y_Q2_K 32
#define NWARPS_Q2_K 4
#endif
...
...
@@ -454,8 +454,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q2_K, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q2_K
;
const
int
mmq_y
=
M
MQ
_Y_Q2_K
;
const
int
mmq_x
=
M
OE
_X_Q2_K
;
const
int
mmq_y
=
M
OE
_Y_Q2_K
;
const
int
nwarps
=
NWARPS_Q2_K
;
moe_q
<
scalar_t
,
QK_K
,
QR2_K
,
QI2_K
,
false
,
block_q2_K
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -472,8 +472,8 @@ static void ggml_moe_q2_K_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q2_K
;
const
int
mmq_y
=
M
MQ
_Y_Q2_K
;
const
int
mmq_x
=
M
OE
_X_Q2_K
;
const
int
mmq_y
=
M
OE
_Y_Q2_K
;
const
int
nwarps
=
NWARPS_Q2_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -495,12 +495,12 @@ static void ggml_moe_q2_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q3_K 64
#define M
MQ
_Y_Q3_K 128
#define M
OE
_X_Q3_K 64
#define M
OE
_Y_Q3_K 128
#define NWARPS_Q3_K 8
#else
#define M
MQ
_X_Q3_K 4
#define M
MQ
_Y_Q3_K 32
#define M
OE
_X_Q3_K 4
#define M
OE
_Y_Q3_K 32
#define NWARPS_Q3_K 4
#endif
...
...
@@ -516,8 +516,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q3_K, 2)
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q3_K
;
const
int
mmq_y
=
M
MQ
_Y_Q3_K
;
const
int
mmq_x
=
M
OE
_X_Q3_K
;
const
int
mmq_y
=
M
OE
_Y_Q3_K
;
const
int
nwarps
=
NWARPS_Q3_K
;
moe_q
<
scalar_t
,
QK_K
,
QR3_K
,
QI3_K
,
false
,
block_q3_K
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -533,8 +533,8 @@ static void ggml_moe_q3_K_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q3_K
;
const
int
mmq_y
=
M
MQ
_Y_Q3_K
;
const
int
mmq_x
=
M
OE
_X_Q3_K
;
const
int
mmq_y
=
M
OE
_Y_Q3_K
;
const
int
nwarps
=
NWARPS_Q3_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -556,12 +556,12 @@ static void ggml_moe_q3_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q4_K 64
#define M
MQ
_Y_Q4_K 128
#define M
OE
_X_Q4_K 64
#define M
OE
_Y_Q4_K 128
#define NWARPS_Q4_K 8
#else
#define M
MQ
_X_Q4_K 4
#define M
MQ
_Y_Q4_K 32
#define M
OE
_X_Q4_K 4
#define M
OE
_Y_Q4_K 32
#define NWARPS_Q4_K 4
#endif
...
...
@@ -576,8 +576,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q4_K, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q4_K
;
const
int
mmq_y
=
M
MQ
_Y_Q4_K
;
const
int
mmq_x
=
M
OE
_X_Q4_K
;
const
int
mmq_y
=
M
OE
_Y_Q4_K
;
const
int
nwarps
=
NWARPS_Q4_K
;
moe_q
<
scalar_t
,
QK_K
,
QR4_K
,
QI4_K
,
true
,
block_q4_K
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -594,8 +594,8 @@ static void ggml_moe_q4_K_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q4_K
;
const
int
mmq_y
=
M
MQ
_Y_Q4_K
;
const
int
mmq_x
=
M
OE
_X_Q4_K
;
const
int
mmq_y
=
M
OE
_Y_Q4_K
;
const
int
nwarps
=
NWARPS_Q4_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -617,12 +617,12 @@ static void ggml_moe_q4_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q5_K 64
#define M
MQ
_Y_Q5_K 128
#define M
OE
_X_Q5_K 64
#define M
OE
_Y_Q5_K 128
#define NWARPS_Q5_K 8
#else
#define M
MQ
_X_Q5_K 4
#define M
MQ
_Y_Q5_K 32
#define M
OE
_X_Q5_K 4
#define M
OE
_Y_Q5_K 32
#define NWARPS_Q5_K 4
#endif
...
...
@@ -637,8 +637,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q5_K, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_K
;
const
int
mmq_y
=
M
MQ
_Y_Q5_K
;
const
int
mmq_x
=
M
OE
_X_Q5_K
;
const
int
mmq_y
=
M
OE
_Y_Q5_K
;
const
int
nwarps
=
NWARPS_Q5_K
;
moe_q
<
scalar_t
,
QK_K
,
QR5_K
,
QI5_K
,
true
,
block_q5_K
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -655,8 +655,8 @@ static void ggml_moe_q5_K_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q5_K
;
const
int
mmq_y
=
M
MQ
_Y_Q5_K
;
const
int
mmq_x
=
M
OE
_X_Q5_K
;
const
int
mmq_y
=
M
OE
_Y_Q5_K
;
const
int
nwarps
=
NWARPS_Q5_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
@@ -678,12 +678,12 @@ static void ggml_moe_q5_K_q8_1_cuda(
}
#if defined(USE_ROCM)
#define M
MQ
_X_Q6_K 64
#define M
MQ
_Y_Q6_K 128
#define M
OE
_X_Q6_K 64
#define M
OE
_Y_Q6_K 128
#define NWARPS_Q6_K 8
#else
#define M
MQ
_X_Q6_K 4
#define M
MQ
_Y_Q6_K 32
#define M
OE
_X_Q6_K 4
#define M
OE
_Y_Q6_K 32
#define NWARPS_Q6_K 4
#endif
...
...
@@ -698,8 +698,8 @@ __launch_bounds__(WARP_SIZE_GGUF* NWARPS_Q6_K, 2)
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
)
{
const
int
mmq_x
=
M
MQ
_X_Q6_K
;
const
int
mmq_y
=
M
MQ
_Y_Q6_K
;
const
int
mmq_x
=
M
OE
_X_Q6_K
;
const
int
mmq_y
=
M
OE
_Y_Q6_K
;
const
int
nwarps
=
NWARPS_Q6_K
;
moe_q
<
scalar_t
,
QK_K
,
QR6_K
,
QI6_K
,
false
,
block_q6_K
,
mmq_x
,
mmq_y
,
nwarps
,
...
...
@@ -716,8 +716,8 @@ static void ggml_moe_q6_K_q8_1_cuda(
const
int
exp_stride
,
const
int
ncols_x
,
const
int
nrows_x
,
const
int
ncols_y
,
const
int
nrows_y
,
const
int
nrows_dst
,
const
int
top_k
,
const
int
tokens_post_padded
,
cudaStream_t
stream
)
{
const
int
mmq_x
=
M
MQ
_X_Q6_K
;
const
int
mmq_y
=
M
MQ
_Y_Q6_K
;
const
int
mmq_x
=
M
OE
_X_Q6_K
;
const
int
mmq_y
=
M
OE
_Y_Q6_K
;
const
int
nwarps
=
NWARPS_Q6_K
;
const
int
block_num_x
=
(
nrows_x
+
mmq_y
-
1
)
/
mmq_y
;
...
...
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
View file @
fcfc474d
...
...
@@ -14,7 +14,7 @@ __global__ void awq_marlin_repack_kernel(
int
n_tiles
=
size_n
/
tile_n_size
;
int
block_k_tiles
=
div_ceil
(
k_tiles
,
gridDim
.
x
);
int
start_k_tile
=
blockIdx
.
x
*
block_k_tiles
;
auto
start_k_tile
=
blockIdx
.
x
*
block_k_tiles
;
if
(
start_k_tile
>=
k_tiles
)
{
return
;
}
...
...
@@ -51,8 +51,8 @@ __global__ void awq_marlin_repack_kernel(
int4
*
sh_ptr
=
sh
+
stage_size
*
pipe
;
if
(
threadIdx
.
x
<
stage_size
)
{
int
k_id
=
threadIdx
.
x
/
stage_n_threads
;
int
n_id
=
threadIdx
.
x
%
stage_n_threads
;
auto
k_id
=
threadIdx
.
x
/
stage_n_threads
;
auto
n_id
=
threadIdx
.
x
%
stage_n_threads
;
int
first_k
=
k_tile_id
*
tile_k_size
;
...
...
@@ -70,8 +70,8 @@ __global__ void awq_marlin_repack_kernel(
return
;
}
int
warp_id
=
threadIdx
.
x
/
32
;
int
th_id
=
threadIdx
.
x
%
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
auto
th_id
=
threadIdx
.
x
%
32
;
if
(
warp_id
>=
4
)
{
return
;
...
...
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
fcfc474d
...
...
@@ -460,7 +460,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int
const
*
__restrict__
perm_int_ptr
,
int4
*
__restrict__
out_int4_ptr
,
int
size_m
,
int
size_k
,
int
lda
,
int
block_rows
)
{
int
start_row
=
block_rows
*
blockIdx
.
x
;
auto
start_row
=
block_rows
*
blockIdx
.
x
;
int
finish_row
=
start_row
+
block_rows
;
if
(
finish_row
>
size_m
)
{
finish_row
=
size_m
;
...
...
@@ -484,7 +484,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int
base_k
=
0
;
for
(
int
i
=
0
;
i
<
iters
;
i
++
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
auto
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
...
...
@@ -494,7 +494,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
if
(
rest
)
{
if
(
threadIdx
.
x
<
rest
)
{
int
cur_k
=
base_k
+
threadIdx
.
x
;
auto
cur_k
=
base_k
+
threadIdx
.
x
;
int
src_pos
=
perm_int_ptr
[
cur_k
];
out_half
[
cur_k
]
=
a_row_half
[
src_pos
];
...
...
@@ -723,8 +723,8 @@ __global__ void Marlin(
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
// For act_order
constexpr
int
k_iter_size
=
tb_k
/
b_sh_wr_iters
;
...
...
@@ -743,7 +743,7 @@ __global__ void Marlin(
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
s_sh_wr
=
threadIdx
.
x
;
auto
s_sh_wr
=
threadIdx
.
x
;
bool
s_sh_wr_pred
=
threadIdx
.
x
<
s_sh_stride
;
// Zero-points
...
...
@@ -756,7 +756,7 @@ __global__ void Marlin(
zp_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
}
int
zp_sh_wr
=
threadIdx
.
x
;
auto
zp_sh_wr
=
threadIdx
.
x
;
bool
zp_sh_wr_pred
=
threadIdx
.
x
<
zp_sh_stride
;
// We use a different scale layout for grouped and column-wise quantization as
...
...
@@ -1047,7 +1047,7 @@ __global__ void Marlin(
int4
*
sh_s_stage
=
sh_s
+
s_sh_stage
*
pipe
;
reinterpret_cast
<
int4
*>
(
&
frag_s
[
k
%
2
])[
0
]
=
sh_s_stage
[
s_sh_rd
];
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1085,7 +1085,7 @@ __global__ void Marlin(
// Determine "position" inside the thread-block (based on warp and
// thread-id)
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
// Each warp processes 4 16-size tiles over N
...
...
@@ -1094,7 +1094,7 @@ __global__ void Marlin(
cur_k
+=
warp_row
*
16
;
int
th_id
=
threadIdx
.
x
%
32
;
auto
th_id
=
threadIdx
.
x
%
32
;
cur_k
+=
(
th_id
%
4
)
*
2
;
// Due to tensor-core layout for fp16 B matrix
int
s_col_shift
=
...
...
@@ -1159,7 +1159,7 @@ __global__ void Marlin(
(
reinterpret_cast
<
int
*>
(
sh_zp_stage
))[
zp_sh_rd
+
i
];
}
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1197,7 +1197,7 @@ __global__ void Marlin(
(
pipe
/
(
group_blocks
/
thread_k_blocks
)));
reinterpret_cast
<
int4
*>
(
&
frag_zpf
[
k
%
2
])[
0
]
=
sh_zp_stage
[
zp_sh_rd
];
}
else
{
int
warp_id
=
threadIdx
.
x
/
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
int
n_warps
=
thread_n_blocks
/
4
;
int
warp_row
=
warp_id
/
n_warps
;
...
...
@@ -1323,7 +1323,7 @@ __global__ void Marlin(
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
auto
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
...
...
@@ -1390,7 +1390,7 @@ __global__ void Marlin(
4
*
(
threadIdx
.
x
/
32
)
+
threadIdx
.
x
%
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
int
c_sh_wr
=
threadIdx
.
x
;
auto
c_sh_wr
=
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
...
...
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
View file @
fcfc474d
...
...
@@ -15,7 +15,7 @@ __global__ void gptq_marlin_repack_kernel(
int
n_tiles
=
size_n
/
tile_n_size
;
int
block_k_tiles
=
div_ceil
(
k_tiles
,
gridDim
.
x
);
int
start_k_tile
=
blockIdx
.
x
*
block_k_tiles
;
auto
start_k_tile
=
blockIdx
.
x
*
block_k_tiles
;
if
(
start_k_tile
>=
k_tiles
)
{
return
;
}
...
...
@@ -71,8 +71,8 @@ __global__ void gptq_marlin_repack_kernel(
if
constexpr
(
has_perm
)
{
if
(
threadIdx
.
x
<
stage_size
)
{
int
k_id
=
threadIdx
.
x
/
stage_n_threads
;
int
n_id
=
threadIdx
.
x
%
stage_n_threads
;
auto
k_id
=
threadIdx
.
x
/
stage_n_threads
;
auto
n_id
=
threadIdx
.
x
%
stage_n_threads
;
uint32_t
const
*
sh_perm_int_ptr
=
reinterpret_cast
<
uint32_t
const
*>
(
sh_perm_ptr
);
...
...
@@ -88,8 +88,8 @@ __global__ void gptq_marlin_repack_kernel(
}
else
{
if
(
threadIdx
.
x
<
stage_size
)
{
int
k_id
=
threadIdx
.
x
/
stage_n_threads
;
int
n_id
=
threadIdx
.
x
%
stage_n_threads
;
auto
k_id
=
threadIdx
.
x
/
stage_n_threads
;
auto
n_id
=
threadIdx
.
x
%
stage_n_threads
;
int
first_k
=
k_tile_id
*
tile_k_size
;
int
first_k_packed
=
first_k
/
pack_factor
;
...
...
@@ -109,8 +109,8 @@ __global__ void gptq_marlin_repack_kernel(
return
;
}
int
warp_id
=
threadIdx
.
x
/
32
;
int
th_id
=
threadIdx
.
x
%
32
;
auto
warp_id
=
threadIdx
.
x
/
32
;
auto
th_id
=
threadIdx
.
x
%
32
;
if
(
warp_id
>=
4
)
{
return
;
...
...
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
View file @
fcfc474d
...
...
@@ -277,12 +277,12 @@ __global__ void Marlin(
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride
)
+
(
threadIdx
.
x
%
b_sh_stride
);
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
;
int
b_sh_rd
=
threadIdx
.
x
;
auto
b_sh_wr
=
threadIdx
.
x
;
auto
b_sh_rd
=
threadIdx
.
x
;
int
s_gl_rd
=
s_gl_stride
*
((
thread_k_blocks
*
slice_row
)
/
group_blocks
)
+
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
int
s_sh_wr
=
threadIdx
.
x
;
auto
s_sh_wr
=
threadIdx
.
x
;
int
s_sh_rd
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
...
...
@@ -455,7 +455,7 @@ __global__ void Marlin(
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride
;
auto
red_idx
=
threadIdx
.
x
/
b_sh_stride
;
constexpr
int
red_sh_stride
=
b_sh_stride
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride
)
+
...
...
@@ -522,7 +522,7 @@ __global__ void Marlin(
4
*
(
threadIdx
.
x
/
32
)
+
threadIdx
.
x
%
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
int
c_sh_wr
=
threadIdx
.
x
;
auto
c_sh_wr
=
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
...
...
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
View file @
fcfc474d
...
...
@@ -353,10 +353,10 @@ __global__ void Marlin(
b_gl_stride
*
(
threadIdx
.
x
/
b_sh_stride
)
+
(
threadIdx
.
x
%
b_sh_stride
);
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
;
int
b_sh_rd
=
threadIdx
.
x
;
auto
b_sh_wr
=
threadIdx
.
x
;
auto
b_sh_rd
=
threadIdx
.
x
;
int
s_tok_gl_rd
=
threadIdx
.
x
;
auto
s_tok_gl_rd
=
threadIdx
.
x
;
// NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10,
// 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for
// thread 0, 1, 2, 3. For more details, refer to mma operand A layout as
...
...
@@ -368,8 +368,8 @@ __global__ void Marlin(
int
s_tok_sh_rd
=
(
threadIdx
.
x
%
32
)
/
4
;
bool
s_tok_sh_wr_pred
=
threadIdx
.
x
<
prob_m
;
int
s_ch_gl_rd
=
s_ch_sh_stride
*
slice_col
+
threadIdx
.
x
;
int
s_ch_sh_wr
=
threadIdx
.
x
;
auto
s_ch_gl_rd
=
s_ch_sh_stride
*
slice_col
+
threadIdx
.
x
;
auto
s_ch_sh_wr
=
threadIdx
.
x
;
int
s_ch_sh_rd
=
16
*
((
threadIdx
.
x
/
32
)
%
(
thread_n_blocks
/
4
))
+
2
*
((
threadIdx
.
x
%
32
)
%
4
);
bool
s_ch_sh_wr_pred
=
threadIdx
.
x
<
s_ch_sh_stride
;
...
...
@@ -558,7 +558,7 @@ __global__ void Marlin(
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride
;
auto
red_idx
=
threadIdx
.
x
/
b_sh_stride
;
constexpr
int
red_sh_stride
=
b_sh_stride
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride
)
+
...
...
@@ -628,7 +628,7 @@ __global__ void Marlin(
8
*
(
threadIdx
.
x
/
32
)
+
(
threadIdx
.
x
%
4
)
*
2
;
c_gl_wr
+=
(
4
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
*
2
;
int
c_sh_wr
=
2
*
threadIdx
.
x
;
auto
c_sh_wr
=
2
*
threadIdx
.
x
;
int
row
=
(
threadIdx
.
x
%
32
)
/
4
;
...
...
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
View file @
fcfc474d
...
...
@@ -273,15 +273,15 @@ __global__ void Marlin_24(
(
threadIdx
.
x
%
b_sh_stride_threads
)
*
b_thread_vecs
;
b_gl_rd
+=
b_sh_stride
*
slice_col
;
b_gl_rd
+=
b_gl_rd_delta_o
*
slice_row
;
int
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
int
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_wr
=
threadIdx
.
x
*
b_thread_vecs
;
auto
b_sh_rd
=
threadIdx
.
x
*
b_thread_vecs
;
int
m_gl_rd
=
m_gl_stride
*
(
threadIdx
.
x
/
(
m_sh_stride
))
+
(
threadIdx
.
x
%
(
m_sh_stride
));
m_gl_rd
+=
(
m_sh_stride
)
*
slice_col
;
m_gl_rd
+=
m_gl_rd_delta_o
*
slice_row
;
int
m_sh_wr
=
threadIdx
.
x
;
int
m_sh_rd
=
threadIdx
.
x
%
16
+
(
threadIdx
.
x
/
32
)
*
16
;
auto
m_sh_wr
=
threadIdx
.
x
;
auto
m_sh_rd
=
threadIdx
.
x
%
16
+
(
threadIdx
.
x
/
32
)
*
16
;
int
s_gl_rd
;
if
constexpr
(
group_blocks
==
-
1
)
{
...
...
@@ -291,7 +291,7 @@ __global__ void Marlin_24(
s_sh_stride
*
slice_col
+
threadIdx
.
x
;
}
int
s_sh_wr
=
threadIdx
.
x
;
auto
s_sh_wr
=
threadIdx
.
x
;
int
s_sh_rd
;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
...
...
@@ -516,7 +516,7 @@ __global__ void Marlin_24(
auto
thread_block_reduce
=
[
&
]()
{
constexpr
int
red_off
=
threads
/
b_sh_stride_threads
/
2
;
if
(
red_off
>=
1
)
{
int
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
auto
red_idx
=
threadIdx
.
x
/
b_sh_stride_threads
;
constexpr
int
red_sh_stride
=
b_sh_stride_threads
*
4
*
2
;
constexpr
int
red_sh_delta
=
b_sh_stride_threads
;
int
red_sh_rd
=
red_sh_stride
*
(
threadIdx
.
x
/
b_sh_stride_threads
)
+
...
...
@@ -583,7 +583,7 @@ __global__ void Marlin_24(
8
*
(
threadIdx
.
x
/
32
)
+
(
threadIdx
.
x
%
32
)
/
4
;
c_gl_wr
+=
(
2
*
thread_n_blocks
)
*
slice_col
;
constexpr
int
c_sh_wr_delta
=
active_threads
;
int
c_sh_wr
=
threadIdx
.
x
;
auto
c_sh_wr
=
threadIdx
.
x
;
int
col
=
2
*
((
threadIdx
.
x
%
32
)
%
4
);
...
...
csrc/quantization/utils.cuh
0 → 100644
View file @
fcfc474d
#pragma once
/**
* Quantization utilities including:
* Adjusted maximum values for qtypes.
* Minimum scaling factors for qtypes.
*/
#include <cmath>
#include <torch/types.h>
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
#define MAYBE_HOST_DEVICE C10_HOST_DEVICE
#else
#include <ATen/hip/HIPContext.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
// ROCm doesn't seem to need C10_HOST_DEVICE for static constexpr
#define MAYBE_HOST_DEVICE
#endif
template
<
typename
T
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
T
,
c10
::
Float8_e4m3fn
>
||
std
::
is_same_v
<
T
,
c10
::
Float8_e4m3fnuz
>
||
std
::
is_same_v
<
T
,
int8_t
>>>
struct
quant_type_max
{
static
constexpr
T
val
()
{
return
std
::
numeric_limits
<
T
>::
max
();
}
};
// Override for c10::Float8_e4m3fnuz: the default max value from pytorch
// (240.0, bit pattern 0x7F) causes accuracy issues when running dynamic
// quantization, so 224.0 (bit pattern 0x7E) is used instead for rocm.
template <>
struct quant_type_max<c10::Float8_e4m3fnuz> {
  // Builds the value directly from its raw bit pattern.
  static constexpr c10::Float8_e4m3fnuz val() {
    using fp8_t = c10::Float8_e4m3fnuz;
    return fp8_t(0x7E, fp8_t::from_bits());
  }
};
// Variable-template shorthand for quant_type_max<T>::val().
// MAYBE_HOST_DEVICE is whatever the platform block at the top of this header
// defined it to be (C10_HOST_DEVICE for non-ROCm builds, empty for ROCm).
template <typename T>
MAYBE_HOST_DEVICE static constexpr T quant_type_max_v{
    quant_type_max<T>::val()};
template
<
typename
T
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
T
,
c10
::
Float8_e4m3fn
>
||
std
::
is_same_v
<
T
,
c10
::
Float8_e4m3fnuz
>
||
std
::
is_same_v
<
T
,
int8_t
>>>
struct
min_scaling_factor
{
C10_DEVICE
C10_ALWAYS_INLINE
static
float
val
()
{
return
1.0
f
/
(
quant_type_max_v
<
T
>
*
512.0
f
);
}
};
// int8_t override: the minimum scaling factor is the float machine epsilon
// rather than the 1 / (max * 512) formula of the generic template.
template <>
struct min_scaling_factor<int8_t> {
  C10_DEVICE C10_ALWAYS_INLINE static float val() {
    using float_limits = std::numeric_limits<float>;
    return float_limits::epsilon();
  }
};
\ No newline at end of file
csrc/rocm/attention.cu
View file @
fcfc474d
...
...
@@ -272,6 +272,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs + 1]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
...
...
@@ -284,18 +285,25 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
int
max_ctx_blocks
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
// clang-format on
constexpr
int
NWARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
laneid
=
threadIdx
.
x
%
WARP_SIZE
;
const
auto
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
const
auto
laneid
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
lane4id
=
laneid
%
4
;
const
int
lane16id
=
laneid
%
16
;
const
int
rowid
=
laneid
/
16
;
const
int
seq_idx
=
blockIdx
.
x
;
const
int
partition_idx
=
blockIdx
.
y
;
const
auto
seq_idx
=
blockIdx
.
x
;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if
(
query_start_loc_ptr
!=
nullptr
&&
(
query_start_loc_ptr
[
seq_idx
+
1
]
-
query_start_loc_ptr
[
seq_idx
])
!=
1
)
{
return
;
}
const
auto
partition_idx
=
blockIdx
.
y
;
constexpr
int
T_PAR_SIZE
=
256
;
// token partition size set to 256
const
int
max_num_partitions
=
gridDim
.
y
;
const
auto
max_num_partitions
=
gridDim
.
y
;
const
int
context_len
=
context_lens
[
seq_idx
];
...
...
@@ -346,9 +354,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// can be interpreted as B8x16 for 8 bit types
_B16x8
Klocal
[
TLOOP
][
QKHELOOP
];
const
int
wg_start_head_idx
=
blockIdx
.
z
*
GQA_RATIO
;
const
int
wg_start_kv_head_idx
=
blockIdx
.
z
;
const
int
total_num_heads
=
gridDim
.
z
*
GQA_RATIO
;
const
auto
wg_start_head_idx
=
blockIdx
.
z
*
GQA_RATIO
;
const
auto
wg_start_kv_head_idx
=
blockIdx
.
z
;
const
auto
total_num_heads
=
gridDim
.
z
*
GQA_RATIO
;
// for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps
// each mfma takes QH16xT16x16HE across warp
...
...
@@ -377,9 +385,10 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// fetch Q in shared across warps and then write to registers
const
int
local_qhead_idx
=
4
*
warpid
+
rowid
;
const
int
global_qhead_idx
=
wg_start_head_idx
+
local_qhead_idx
;
const
int64_t
seq_idx64
=
static_cast
<
int64_t
>
(
seq_idx
);
const
int64_t
query_start_off
=
static_cast
<
int64_t
>
(
query_start_loc_ptr
?
query_start_loc_ptr
[
seq_idx
]
:
seq_idx
);
const
scalar_t
*
q_ptr
=
q
+
seq_idx64
*
q_stride
+
global_qhead_idx
*
HEAD_SIZE
;
q
+
query_start_off
*
q_stride
+
global_qhead_idx
*
HEAD_SIZE
;
const
int
qhead_element
=
lane16id
*
CONTIGUOUS_SCALAR_ELEMS_16B
;
if
((
local_qhead_idx
<
GQA_RATIO
)
&&
(
qhead_element
<
HEAD_SIZE
))
{
...
...
@@ -777,6 +786,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs + 1]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
...
...
@@ -789,14 +799,20 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
int
max_ctx_blocks
,
const
float
*
k_scale
,
const
float
*
v_scale
)
{
// clang-format on
constexpr
int
NWARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
laneid
=
threadIdx
.
x
%
WARP_SIZE
;
const
auto
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
const
auto
laneid
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
lane4id
=
laneid
%
4
;
const
int
seq_idx
=
blockIdx
.
x
;
const
int
partition_idx
=
blockIdx
.
y
;
const
int
partition_size
=
blockDim
.
x
;
const
int
max_num_partitions
=
gridDim
.
y
;
const
auto
seq_idx
=
blockIdx
.
x
;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if
(
query_start_loc_ptr
!=
nullptr
&&
(
query_start_loc_ptr
[
seq_idx
+
1
]
-
query_start_loc_ptr
[
seq_idx
]
!=
1
))
{
return
;
}
const
auto
partition_idx
=
blockIdx
.
y
;
const
auto
partition_size
=
blockDim
.
x
;
const
auto
max_num_partitions
=
gridDim
.
y
;
const
int
context_len
=
context_lens
[
seq_idx
];
const
int
partition_start_token_idx
=
partition_idx
*
partition_size
;
...
...
@@ -838,8 +854,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
qk_max
[
h
]
=
-
FLT_MAX
;
}
const
int
wg_start_head_idx
=
blockIdx
.
z
*
GQA_RATIO
;
const
int
wg_start_kv_head_idx
=
blockIdx
.
z
;
const
auto
wg_start_head_idx
=
blockIdx
.
z
*
GQA_RATIO
;
const
auto
wg_start_kv_head_idx
=
blockIdx
.
z
;
const
int
warp_start_token_idx
=
partition_start_token_idx
+
warpid
*
WARP_SIZE
;
...
...
@@ -857,7 +873,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
// token id within partition
const
int
local_token_idx
=
threadIdx
.
x
;
const
auto
local_token_idx
=
threadIdx
.
x
;
// token id within sequence
const
int
global_token_idx
=
partition_start_token_idx
+
local_token_idx
;
...
...
@@ -882,9 +898,11 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}
// fetch q elements
// every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems
// every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems
const
int64_t
query_start_off
=
static_cast
<
int64_t
>
(
query_start_loc_ptr
?
query_start_loc_ptr
[
seq_idx
]
:
seq_idx
);
const
scalar_t
*
q_ptr
=
q
+
seq_idx
*
q_stride
+
wg_start_head_idx
*
HEAD_SIZE
;
q
+
query_start_off
*
q_stride
+
wg_start_head_idx
*
HEAD_SIZE
;
const
_B16x8
*
q_ptrh8
=
reinterpret_cast
<
const
_B16x8
*>
(
q_ptr
);
const
int
qhead_elemh8
=
laneid
/
4
;
...
...
@@ -1126,7 +1144,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
__syncthreads
();
const
int
num_heads
=
gridDim
.
z
*
GQA_RATIO
;
const
auto
num_heads
=
gridDim
.
z
*
GQA_RATIO
;
float
*
max_logits_ptr
=
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
partition_idx
;
float
*
exp_sums_ptr
=
...
...
@@ -1267,15 +1285,24 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs + 1]
const
int
max_num_partitions
)
{
const
int
num_heads
=
gridDim
.
x
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
seq_idx
=
blockIdx
.
y
;
const
auto
num_heads
=
gridDim
.
x
;
const
auto
head_idx
=
blockIdx
.
x
;
const
auto
seq_idx
=
blockIdx
.
y
;
// NOTE queries with sequence len > 1 are prefills and taken care by another
// kernel.
if
(
query_start_loc_ptr
!=
nullptr
&&
(
query_start_loc_ptr
[
seq_idx
+
1
]
-
query_start_loc_ptr
[
seq_idx
]
!=
1
))
{
return
;
}
const
int
context_len
=
context_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
context_len
,
PARTITION_SIZE
);
[[
maybe_unused
]]
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
[[
maybe_unused
]]
const
int
laneid
=
threadIdx
.
x
%
WARP_SIZE
;
const
auto
warpid
=
threadIdx
.
x
/
WARP_SIZE
;
[[
maybe_unused
]]
const
auto
laneid
=
threadIdx
.
x
%
WARP_SIZE
;
__shared__
float
shared_global_exp_sum
;
// max num partitions supported is warp_size * NPAR_LOOPS
...
...
@@ -1294,7 +1321,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
const
int
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
const
auto
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
valid_partition
[
i
]
=
(
partition_no
<
num_partitions
)
?
partition_no
:
last_valid_partition
;
}
...
...
@@ -1324,7 +1351,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
}
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
const
int
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
const
auto
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
rescaled_exp_sum
[
i
]
*=
(
partition_no
<
num_partitions
)
?
expf
(
reg_max_logit
[
i
]
-
max_logit
)
:
0.0
f
;
...
...
@@ -1336,7 +1363,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
}
#pragma unroll
for
(
int
i
=
0
;
i
<
NPAR_LOOPS
;
i
++
)
{
const
int
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
const
auto
partition_no
=
i
*
WARP_SIZE
+
threadIdx
.
x
;
shared_exp_sums
[
partition_no
]
=
rescaled_exp_sum
[
i
];
}
...
...
@@ -1439,7 +1466,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
__fdividef
(
1.0
f
,
shared_global_exp_sum
+
1e-6
f
);
acc
*=
inv_global_exp_sum
;
OUTT
*
out_ptr
=
out
+
static_cast
<
int64_t
>
(
seq_idx
)
*
num_heads
*
HEAD_SIZE
+
const
int64_t
query_start_off
=
static_cast
<
int64_t
>
(
query_start_loc_ptr
?
query_start_loc_ptr
[
seq_idx
]
:
seq_idx
);
OUTT
*
out_ptr
=
out
+
query_start_off
*
num_heads
*
HEAD_SIZE
+
static_cast
<
int64_t
>
(
head_idx
)
*
HEAD_SIZE
;
if
constexpr
(
std
::
is_same
<
OUTT
,
bit8_t
>::
value
)
{
out_ptr
[
threadIdx
.
x
]
=
...
...
@@ -1466,6 +1495,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
...
...
@@ -1492,6 +1522,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
...
...
@@ -1515,6 +1546,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads, max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads, max_num_partitions, head_size]
const
int
*
__restrict__
context_lens
,
// [num_seqs]
const
int
*
__restrict__
query_start_loc_ptr
,
// [num_seqs]
const
int
max_num_partitions
)
{
UNREACHABLE_CODE
}
...
...
@@ -1528,10 +1560,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr,
max_num_blocks_per_seq,
\
alibi_slopes_ptr, q_stride, kv_block_stride,
kv_head_stride,
\
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,
max_ctx_blocks,
\
k_scale_ptr, v_scale_ptr);
block_tables_ptr, context_lens_ptr,
query_start_loc_ptr,
\
max_num_blocks_per_seq,
alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride,
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks,
k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \
paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
...
...
@@ -1539,17 +1571,17 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr,
max_num_blocks_per_seq,
\
alibi_slopes_ptr, q_stride, kv_block_stride,
kv_head_stride,
\
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,
max_ctx_blocks,
\
k_scale_ptr, v_scale_ptr);
block_tables_ptr, context_lens_ptr,
query_start_loc_ptr,
\
max_num_blocks_per_seq,
alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride,
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks,
k_scale_ptr, v_scale_ptr);
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, max_num_partitions);
context_lens_ptr,
query_start_loc_ptr,
max_num_partitions);
template
<
typename
T
,
typename
KVT
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
int
BLOCK_SIZE
,
int
HEAD_SIZE
,
typename
OUTT
,
int
PARTITION_SIZE_OLD
,
...
...
@@ -1559,9 +1591,10 @@ void paged_attention_custom_launcher(
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
const
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context_lens
,
int
max_context_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
)
{
int
num_seqs
=
query
.
size
(
0
);
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
int
max_context_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
)
{
int
num_seqs
=
block_tables
.
size
(
0
);
int
num_heads
=
query
.
size
(
1
);
int
head_size
=
query
.
size
(
2
);
int
max_num_blocks_per_seq
=
block_tables
.
size
(
1
);
...
...
@@ -1569,6 +1602,13 @@ void paged_attention_custom_launcher(
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
// NOTE: query start location is optional; for V0 decode it should not be used.
// If batch contains mix of prefills and decode, prefills should be skipped.
const
int
*
query_start_loc_ptr
=
query_start_loc
?
reinterpret_cast
<
const
int
*>
(
query_start_loc
.
value
().
data_ptr
())
:
nullptr
;
// NOTE: alibi_slopes is optional.
const
float
*
alibi_slopes_ptr
=
alibi_slopes
...
...
@@ -1700,8 +1740,8 @@ void paged_attention_custom_launcher(
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens,
max_contex
t_l
en
, \
alibi_slopes, k_scale, v_scale);
num_kv_heads, scale, block_tables, context_lens,
query_star
t_l
oc
, \
max_context_len,
alibi_slopes, k_scale, v_scale);
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
PSIZE) \
...
...
@@ -1750,6 +1790,7 @@ void paged_attention(
double
scale
,
torch
::
Tensor
&
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
torch
::
Tensor
&
context_lens
,
// [num_seqs]
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
// [num_seqs]
int64_t
block_size
,
int64_t
max_context_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
...
...
csrc/rocm/ops.h
View file @
fcfc474d
...
...
@@ -7,8 +7,9 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int64_t
num_kv_heads
,
double
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
context_lens
,
int64_t
block_size
,
int64_t
max_context_len
,
torch
::
Tensor
&
context_lens
,
const
std
::
optional
<
torch
::
Tensor
>&
query_start_loc
,
int64_t
block_size
,
int64_t
max_context_len
,
const
std
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
torch
::
Tensor
&
k_scale
,
torch
::
Tensor
&
v_scale
);
csrc/rocm/torch_bindings.cpp
View file @
fcfc474d
...
...
@@ -23,7 +23,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
" Tensor context_lens, int block_size,"
" Tensor context_lens,"
" Tensor? query_start_loc,"
" int block_size,"
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
...
...
csrc/torch_bindings.cpp
View file @
fcfc474d
...
...
@@ -31,6 +31,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"weak_ref_tensor(Tensor input) -> Tensor"
);
ops
.
impl
(
"weak_ref_tensor"
,
torch
::
kCUDA
,
&
weak_ref_tensor
);
ops
.
def
(
"get_cuda_view_from_cpu_tensor(Tensor cpu_tensor) -> Tensor"
);
ops
.
impl
(
"get_cuda_view_from_cpu_tensor"
,
torch
::
kCPU
,
&
get_cuda_view_from_cpu_tensor
);
// Attention ops
// Compute the attention between an input query and the cached
// keys/values using PagedAttention.
...
...
@@ -481,7 +485,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif
// Dequantization for GGML.
ops
.
def
(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor"
);
ops
.
def
(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor"
);
ops
.
impl
(
"ggml_dequantize"
,
torch
::
kCUDA
,
&
ggml_dequantize
);
// mmvq kernel for GGML.
...
...
@@ -555,6 +561,35 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool"
);
ops
.
impl
(
"cutlass_scaled_mm_supports_fp8"
,
&
cutlass_scaled_mm_supports_fp8
);
// Check if cutlass grouped gemm is supported for CUDA devices of the given
// capability
ops
.
def
(
"cutlass_group_gemm_supported(int cuda_device_capability) -> bool"
);
ops
.
impl
(
"cutlass_group_gemm_supported"
,
&
cutlass_group_gemm_supported
);
// CUTLASS w8a8 grouped GEMM
ops
.
def
(
"cutlass_moe_mm(Tensor! out_tensors, Tensor a_tensors, Tensor b_tensors, "
" Tensor a_scales, Tensor b_scales, Tensor expert_offsets, "
" Tensor problem_sizes, Tensor a_strides, "
" Tensor b_strides, Tensor c_strides) -> ()"
,
{
stride_tag
});
ops
.
impl
(
"cutlass_moe_mm"
,
torch
::
kCUDA
,
&
cutlass_moe_mm
);
// A function that computes data required to run fused MoE with w8a8 grouped
// GEMM. It takes topk_ids as an input, and computes expert_offsets
// (token start indices of each expert). In addition to this, it computes
// problem sizes for each expert's multiplication used by the two mms called
// from fused MoE operation, and arrays with permutations required to shuffle
// and de-shuffle the input/output of the fused operation.
ops
.
def
(
"get_cutlass_moe_mm_data(Tensor topk_ids, Tensor! expert_offsets, "
" Tensor! problem_sizes1, Tensor! problem_sizes2, "
" Tensor! input_permutation, "
" Tensor! output_permutation, int num_experts, "
" int n, int k) -> ()"
,
{
stride_tag
});
ops
.
impl
(
"get_cutlass_moe_mm_data"
,
torch
::
kCUDA
,
&
get_cutlass_moe_mm_data
);
// Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
ops
.
def
(
"cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
...
...
@@ -792,7 +827,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
// Custom all-reduce kernels
custom_ar
.
def
(
"init_custom_ar(int[] ipc_tensors, Tensor rank_data, "
"int rank, bool full
_nvlink
) -> int"
);
"int rank, bool full
y_connected
) -> int"
);
custom_ar
.
impl
(
"init_custom_ar"
,
torch
::
kCUDA
,
&
init_custom_ar
);
custom_ar
.
def
(
"all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, "
...
...
@@ -805,6 +840,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
custom_ar
.
def
(
"register_buffer"
,
&
register_buffer
);
custom_ar
.
def
(
"get_graph_buffer_ipc_meta"
,
&
get_graph_buffer_ipc_meta
);
custom_ar
.
def
(
"register_graph_buffers"
,
&
register_graph_buffers
);
custom_ar
.
def
(
"allocate_shared_buffer_and_handle"
,
&
allocate_shared_buffer_and_handle
);
custom_ar
.
def
(
"open_mem_handle(Tensor mem_handle) -> int"
,
&
open_mem_handle
);
...
...
Dockerfile
→
docker/
Dockerfile
View file @
fcfc474d
File moved
Dockerfile.arm
→
docker/
Dockerfile.arm
View file @
fcfc474d
File moved
docker/Dockerfile.cpu
0 → 100644
View file @
fcfc474d
# This vLLM Dockerfile is used to construct an image that can build and run vLLM on x86 CPU platforms.
#
# Build targets:
# vllm-openai (default): used for serving deployment
# vllm-test: used for CI tests
# vllm-dev: used for development
#
# Build arguments:
# PYTHON_VERSION=3.12 (default)|3.11|3.10|3.9
# VLLM_CPU_DISABLE_AVX512=false (default)|true
#
######################### BASE IMAGE #########################
# Base stage: Ubuntu 22.04 with the system toolchain, uv, a Python virtual
# environment and the CPU Python requirements. The vllm-build, vllm-test and
# vllm-openai stages below all start FROM this stage.
FROM ubuntu:22.04 AS base
WORKDIR /workspace/
# Python version passed to `uv venv` when the virtual environment is created.
ARG PYTHON_VERSION=3.12
# Extra package index; re-exported below as PIP_EXTRA_INDEX_URL and
# UV_EXTRA_INDEX_URL so both pip and uv see it.
ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu"
# Install minimal dependencies and uv
# The apt cache and state directories are BuildKit cache mounts. gcc-12/g++-12
# are registered as the `gcc`/`g++` alternatives, then the uv installer script
# is downloaded and piped to sh.
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update -y \
&& apt-get install -y --no-install-recommends ccache git curl wget ca-certificates \
gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 \
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 \
&& curl -LsSf https://astral.sh/uv/install.sh | sh
# ccache directory; the same path is used as a cache mount target by the
# build and dev stages.
ENV CCACHE_DIR=/root/.cache/ccache
ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache
# NOTE(review): /root/.local/bin is presumably where the uv installer places
# its binary -- confirm against the installer.
ENV PATH="/root/.local/bin:$PATH"
# Create the virtual environment at /opt/venv and put its bin/ first on PATH
# so later `python3` / `uv pip` invocations resolve to it.
ENV VIRTUAL_ENV="/opt/venv"
RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# uv HTTP timeout setting (presumably in seconds -- TODO confirm).
ENV UV_HTTP_TIMEOUT=500
# Install Python dependencies
# requirements/common.txt is bind-mounted alongside requirements/cpu.txt;
# presumably cpu.txt includes it -- verify against the requirements files.
ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
ENV UV_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
ENV UV_INDEX_STRATEGY="unsafe-best-match"
ENV UV_LINK_MODE="copy"
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,src=requirements/common.txt,target=requirements/common.txt \
--mount=type=bind,src=requirements/cpu.txt,target=requirements/cpu.txt \
uv pip install --upgrade pip && \
uv pip install -r requirements/cpu.txt
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install intel-openmp==2024.2.1 intel_extension_for_pytorch==2.6.0
# Preload tcmalloc (from libtcmalloc-minimal4 installed above) and libiomp5
# from the virtual environment (presumably provided by the intel-openmp
# package -- confirm) ahead of any pre-existing LD_PRELOAD entries.
ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/opt/venv/lib/libiomp5.so:$LD_PRELOAD"
# Set the core-file size limit to 0 for shells that source root's .bashrc.
RUN echo 'ulimit -c 0' >> ~/.bashrc
######################### BUILD IMAGE #########################
# Build stage: installs the build requirements, copies the build context into
# /workspace/vllm and builds a wheel with VLLM_TARGET_DEVICE=cpu. The test and
# release stages bind-mount /workspace/vllm/dist from this stage to install
# that wheel.
FROM base AS vllm-build
# When set to anything other than 0, tools/check_repo.sh is run after the
# sources are copied.
ARG GIT_REPO_CHECK=0
# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
# No default is given here, so the environment variable is empty unless the
# build argument is supplied.
ARG VLLM_CPU_DISABLE_AVX512
ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
WORKDIR /workspace/vllm
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \
uv pip install -r requirements/build.txt
# Copy the whole build context into the working directory.
COPY . .
# .git is bind-mounted from the build context for the optional repository check.
RUN --mount=type=bind,source=.git,target=.git \
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi
# Build the wheel; the uv and ccache directories are cache mounts and .git is
# bind-mounted for the duration of the build.
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=cache,target=/root/.cache/ccache \
--mount=type=bind,source=.git,target=.git \
VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel
######################### DEV IMAGE #########################
# Development stage: builds on vllm-build (sources already in /workspace/vllm),
# adds editor/NUMA tools, installs the test utilities and vLLM itself in
# development mode, then installs the dev requirements and pre-commit hooks.
FROM vllm-build AS vllm-dev
WORKDIR /workspace/vllm
# Refresh the package lists in this RUN before installing: the lists fetched
# in the base stage were written to the /var/lib/apt cache mount rather than
# the image, so they are missing whenever that cache is cold or was pruned.
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
apt-get update -y \
&& apt-get install -y --no-install-recommends vim numactl
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install -e tests/vllm_test_utils
# Install vLLM for the CPU target in development mode.
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=cache,target=/root/.cache/ccache \
--mount=type=bind,source=.git,target=.git \
VLLM_TARGET_DEVICE=cpu python3 setup.py develop
# Install the dev requirements and register the pre-commit and commit-msg hooks.
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install -r requirements/dev.txt && \
pre-commit install --hook-type pre-commit --hook-type commit-msg
ENTRYPOINT ["bash"]
######################### TEST IMAGE #########################
# Test stage: starts from base, installs the test requirements and the wheel
# built in vllm-build, then copies the tests, examples and benchmarks
# directories from the build context.
FROM base AS vllm-test
WORKDIR /workspace/
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,src=requirements/test.txt,target=requirements/test.txt \
uv pip install -r requirements/test.txt
# The wheel is bind-mounted from the build stage, so it is not copied into
# this image.
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=bind,from=vllm-build,src=/workspace/vllm/dist,target=dist \
uv pip install dist/*.whl
# COPY rather than ADD: these are plain local directories, so none of ADD's
# extra behaviour (archive extraction, URL fetching) is wanted.
COPY ./tests/ ./tests/
COPY ./examples/ ./examples/
COPY ./benchmarks/ ./benchmarks/
# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
uv pip install -e tests/vllm_test_utils
ENTRYPOINT ["bash"]
######################### RELEASE IMAGE #########################
# Release stage (the default target per the file header): starts from base,
# installs the wheel built in vllm-build and runs the OpenAI-compatible API
# server module as the entrypoint.
FROM base AS vllm-openai
WORKDIR /workspace/
# The wheel is bind-mounted from the build stage for the install only.
# NOTE(review): the ccache cache mount looks unused by a wheel install -- confirm
# whether it is needed here.
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=type=cache,target=/root/.cache/ccache \
--mount=type=bind,from=vllm-build,src=/workspace/vllm/dist,target=dist \
uv pip install dist/*.whl
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
Dockerfile.hpu
→
docker/
Dockerfile.hpu
View file @
fcfc474d
File moved
Dockerfile.neuron
→
docker/
Dockerfile.neuron
View file @
fcfc474d
File moved
Prev
1
2
3
4
5
6
7
8
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment