Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0640f227
Commit
0640f227
authored
Sep 09, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.6.0' into v0.6.0-dev
parents
82f1ffdf
32e7db25
Changes
335
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1680 additions
and
401 deletions
+1680
-401
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+86
-0
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+346
-0
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-0
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+1
-2
vllm/model_executor/layers/quantization/awq_triton.py
vllm/model_executor/layers/quantization/awq_triton.py
+304
-0
vllm/model_executor/layers/quantization/bitsandbytes.py
vllm/model_executor/layers/quantization/bitsandbytes.py
+195
-36
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+19
-18
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+287
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+0
-2
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
...pressed_tensors/schemes/compressed_tensors_unquantized.py
+0
-49
vllm/model_executor/layers/quantization/experts_int8.py
vllm/model_executor/layers/quantization/experts_int8.py
+15
-11
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+21
-13
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+26
-29
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+58
-45
vllm/model_executor/layers/quantization/gptq_marlin_24.py
vllm/model_executor/layers/quantization/gptq_marlin_24.py
+49
-53
vllm/model_executor/layers/quantization/neuron_quant.py
vllm/model_executor/layers/quantization/neuron_quant.py
+67
-0
vllm/model_executor/layers/quantization/qqq.py
vllm/model_executor/layers/quantization/qqq.py
+51
-63
vllm/model_executor/layers/quantization/tpu_int8.py
vllm/model_executor/layers/quantization/tpu_int8.py
+11
-10
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+0
-27
vllm/model_executor/layers/rejection_sampler.py
vllm/model_executor/layers/rejection_sampler.py
+141
-43
No files found.
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
0 → 100644
View file @
0640f227
# Copyright (c) 2024, Tri Dao.
from
typing
import
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    seq_idx: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out=None,
    activation: str = "silu",
):
    """Run the causal 1-D convolution forward op, optionally with SiLU.

    Shapes, as stated by the original implementation:
        x: (batch, dim, seqlen)
        weight: (dim, width)
        bias: (dim,)
        seq_idx: (batch, seqlen)
        initial_states: (batch, dim, width - 1)
        final_states_out: (batch, dim, width - 1), to be written to
        activation: either None or "silu" or "swish"

    Returns:
        ``(out, final_states)``. ``out`` is (batch, dim, seqlen);
        ``final_states`` is ``None`` unless ``return_final_states`` is True,
        in which case it is ``final_states_out`` (caller supplied or freshly
        allocated here).

    Raises:
        NotImplementedError: if ``activation`` is not None/"silu"/"swish".
        AssertionError: on unsupported argument combinations or layouts.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    # x is used as-is when either of its last two axes has unit stride;
    # otherwise it is copied into a contiguous buffer.
    if not (x.stride(2) == 1 or x.stride(1) == 1):
        x = x.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    if seq_idx is not None:
        # seq_idx cannot be combined with initial or final states.
        assert (initial_states is None
                ), "initial_states must be None if seq_idx is not None"
        assert (
            not return_final_states
        ), "If seq_idx is not None, we don't return final_states_out"
        seq_idx = seq_idx.contiguous()

    if initial_states is not None:
        if not (initial_states.stride(2) == 1
                or initial_states.stride(1) == 1):
            initial_states = initial_states.contiguous()

    if not return_final_states:
        # Any caller-provided buffer is ignored when states are not requested.
        final_states_out = None
    else:
        assert (
            x.stride(1) == 1
        ), "Only channel-last layout support returning final_states_out"
        if final_states_out is None:
            batch, dim, _seqlen = x.shape
            kernel_width = weight.shape[1]
            # Allocate as (batch, width - 1, dim) and transpose, so the
            # (batch, dim, width - 1) view has unit stride along dim.
            final_states_out = torch.empty(batch,
                                           kernel_width - 1,
                                           dim,
                                           device=x.device,
                                           dtype=x.dtype).transpose(1, 2)
        else:
            assert (final_states_out.stride(2) == 1
                    or final_states_out.stride(1) == 1)

    # The op takes a boolean flag rather than the activation name.
    apply_activation = activation in ("silu", "swish")
    out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states,
                                final_states_out, apply_activation)
    return out, final_states_out
def causal_conv1d_update(x: torch.Tensor,
                         conv_state: torch.Tensor,
                         weight: torch.Tensor,
                         bias: Optional[torch.Tensor] = None,
                         activation: Optional[str] = None):
    """Forward one convolution update step to the custom op.

    Shapes, as stated by the original implementation:
        x: (batch, dim)
        conv_state: (batch, dim, width)
        weight: (dim, width)
        bias: (dim,)

    Returns:
        Whatever ``ops.causal_conv1d_update`` returns; documented by the
        original as ``out: (batch, dim)``.

    Raises:
        NotImplementedError: if ``activation`` is not None/"silu"/"swish".
    """
    allowed_activations = (None, "silu", "swish")
    if activation not in allowed_activations:
        raise NotImplementedError("activation must be None, silu, or swish")

    # The op only needs to know whether an activation is applied at all.
    use_activation = activation in ("silu", "swish")
    return ops.causal_conv1d_update(x, conv_state, weight, bias,
                                    use_activation)
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
0 → 100644
View file @
0640f227
# Copyright (c) 2024, Tri Dao, Albert Gu.
import
torch
import
triton
import
triton.language
as
tl
from
packaging
import
version
from
vllm
import
_custom_ops
as
ops
# True when the installed Triton reports version 3.0.0 or newer; selects which
# softplus implementation is defined below.
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")

if TRITON3:

    # Softplus for Triton >= 3.0.0: log(exp(dt) + 1) where dt <= 20, and dt
    # itself where dt > 20.
    @triton.jit
    def softplus(dt):
        dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
        return dt
else:

    # Softplus for older Triton: same thresholding at 20, but computed with
    # tl.math.log1p(exp(dt)).
    @triton.jit
    def softplus(dt):
        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
        return dt
# Single-step selective-scan state update.
#
# Program ids: axis 0 selects a tile of BLOCK_SIZE_M rows along `dim`,
# axis 1 selects the batch entry, axis 2 selects the head. Each program
# loads a (BLOCK_SIZE_M, BLOCK_SIZE_DSTATE) tile of `state`, computes
# state * dA + dB * x, stores the tile back through the same pointers,
# and stores out = sum(state * C, axis=1) (plus x * D when HAS_D, then
# multiplied by z * sigmoid(z) when HAS_Z).
#
# HAS_DT_BIAS / HAS_D / HAS_Z are derived from whether the matching pointer
# argument is None; BLOCK_SIZE_DSTATE is the next power of two >= dstate.
@triton.heuristics(
    {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics(
    {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr,
    x_ptr,
    dt_ptr,
    dt_bias_ptr,
    A_ptr,
    B_ptr,
    C_ptr,
    D_ptr,
    z_ptr,
    out_ptr,
    # Matrix dimensions
    batch,
    nheads,
    dim,
    dstate,
    nheads_ngroups_ratio,
    # Strides
    stride_state_batch,
    stride_state_head,
    stride_state_dim,
    stride_state_dstate,
    stride_x_batch,
    stride_x_head,
    stride_x_dim,
    stride_dt_batch,
    stride_dt_head,
    stride_dt_dim,
    stride_dt_bias_head,
    stride_dt_bias_dim,
    stride_A_head,
    stride_A_dim,
    stride_A_dstate,
    stride_B_batch,
    stride_B_group,
    stride_B_dstate,
    stride_C_batch,
    stride_C_group,
    stride_C_dstate,
    stride_D_head,
    stride_D_dim,
    stride_z_batch,
    stride_z_head,
    stride_z_dim,
    stride_out_batch,
    stride_out_head,
    stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    # Advance every base pointer to this program's batch entry and head.
    state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    # B and C are indexed per group: heads map to groups by integer division.
    B_ptr += pid_b * stride_B_batch + (
        pid_h // nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (
        pid_h // nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head

    # offs_m: rows along `dim` handled by this program; offs_n: the padded
    # dstate range. Both are masked against dim / dstate on load and store.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
                              offs_n[None, :] * stride_state_dstate)
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +
                      offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    # Out-of-range elements are loaded as 0.0. Unlike the other operands,
    # state is not explicitly cast to float32 here.
    state = tl.load(state_ptrs,
                    mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
                    other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        # General path: one dt per row and a full (dim, dstate) tile of A.
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
                          other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(A_ptrs,
                    mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
                    other=0.0).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        # Tied path: a single dt, dt_bias and A value is read per program.
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix

    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
    # Recurrence step; the updated tile overwrites the input state in place.
    state = state * dA + dB * x[:, None]
    tl.store(state_ptrs,
             state,
             mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
    # Reduce over dstate to get one output value per row.
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)
def selective_state_update(state,
                           x,
                           dt,
                           A,
                           B,
                           C,
                           D=None,
                           z=None,
                           dt_bias=None,
                           dt_softplus=False):
    """Advance the selective-scan state by one step and return the output.

    ``state`` is updated in place by the Triton kernel. Inputs without a
    heads axis get one inserted and the returned tensor has it removed again.

    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
        dt_softplus: whether the kernel applies softplus to dt (+ dt_bias)
    Return:
        out: (batch, dim) or (batch, nheads, dim)
    Raises:
        AssertionError: if the argument shapes are inconsistent.
    """
    has_heads = state.dim() > 3
    # Normalize every argument to the layout with an explicit heads/groups
    # axis so the kernel only has to handle one case.
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)

    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)

    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)

    # Optional tensors contribute zero strides when absent. These must be
    # tuples of the right length: the previous inline form
    # `*(a, b) if cond else 0` unpacked the whole conditional expression and
    # therefore raised TypeError (`*0`) whenever dt_bias or D was None.
    z_strides = ((z.stride(0), z.stride(1), z.stride(2))
                 if z is not None else (0, 0, 0))
    dt_bias_strides = ((dt_bias.stride(0), dt_bias.stride(1))
                       if dt_bias is not None else (0, 0))
    D_strides = (D.stride(0), D.stride(1)) if D is not None else (0, 0)

    # We don't want autotune since it will overwrite the state
    # We instead tune by hand.
    if dstate <= 16:
        BLOCK_SIZE_M, num_warps = 32, 4
    elif dstate <= 32:
        BLOCK_SIZE_M, num_warps = 16, 4
    elif dstate <= 64:
        BLOCK_SIZE_M, num_warps = 8, 4
    elif dstate <= 128:
        BLOCK_SIZE_M, num_warps = 4, 4
    else:
        BLOCK_SIZE_M, num_warps = 4, 8

    # The kernel's scalar fast path is used when A, dt and (if given) dt_bias
    # are broadcast (stride 0) along the trailing axes. dt_bias may be None,
    # in which case it places no constraint.
    tie_hdim = (A.stride(-1) == 0 and A.stride(-2) == 0
                and dt.stride(-1) == 0
                and (dt_bias is None or dt_bias.stride(-1) == 0))

    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state,
            x,
            dt,
            dt_bias,
            A,
            B,
            C,
            D,
            z,
            out,
            batch,
            nheads,
            dim,
            dstate,
            nheads // ngroups,
            state.stride(0),
            state.stride(1),
            state.stride(2),
            state.stride(3),
            x.stride(0),
            x.stride(1),
            x.stride(2),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            dt_bias_strides[0],
            dt_bias_strides[1],
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            C.stride(0),
            C.stride(1),
            C.stride(2),
            D_strides[0],
            D_strides[1],
            z_strides[0],
            z_strides[1],
            z_strides[2],
            out.stride(0),
            out.stride(1),
            out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    if not has_heads:
        out = out.squeeze(1)
    return out
def selective_scan_fn(u,
                      delta,
                      A,
                      B,
                      C,
                      D=None,
                      z=None,
                      delta_bias=None,
                      delta_softplus=False,
                      return_last_state=False,
                      position_indices=None,
                      prev_state=None):
    """Run the selective-scan forward op over a whole sequence.

    Inputs whose last axis is not unit-stride are made contiguous, and 3-D
    ``B``/``C`` get a size-1 axis inserted at position 1 before the call to
    ``ops.selective_scan_fwd``.

    A float32 scratch buffer of shape
    ``(u.shape[0], u.shape[1], n_chunks, 2 * A.shape[1])`` is passed to the
    op, with one chunk per 2048 positions of ``u``'s last axis (rounded up).
    In chunk 0 the even slots are set to 1 and, when ``prev_state`` is given,
    the odd slots are filled from it.

    Returns:
        ``out`` (or the op's extra output ``rest[0]`` when ``z`` is given).
        If ``return_last_state`` is True, returns ``(out, last_state)`` where
        ``last_state`` is the odd slots of the last chunk of the buffer
        returned by the op, with shape (batch, dim, dstate).
    """
    if u.stride(-1) != 1:
        u = u.contiguous()
    if delta.stride(-1) != 1:
        delta = delta.contiguous()
    if D is not None:
        D = D.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    if B.dim() == 3:
        B = B.unsqueeze(1)
    if C.dim() == 3:
        C = C.unsqueeze(1)

    # Integer ceil-division; avoids the float round trip of int(a / b).
    n_chunks = (u.shape[-1] + 2048 - 1) // 2048
    x = torch.zeros((
        u.shape[0],
        u.shape[1],
        n_chunks,
        A.shape[1] * 2,
    ),
                    device=u.device,
                    dtype=torch.float32,
                    requires_grad=False)
    x[:, :, 0, 0::2] = 1
    if prev_state is not None:
        x[:, :, 0, 1::2].copy_(prev_state)

    out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z,
                                           delta_bias, delta_softplus,
                                           position_indices, x)
    last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)

    # With z, the op's extra output is returned instead of `out`.
    result = out if z is None else rest[0]
    if not return_last_state:
        return result
    return result, last_state
vllm/model_executor/layers/quantization/__init__.py
View file @
0640f227
...
...
@@ -22,6 +22,8 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
GPTQMarlin24Config
)
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.neuron_quant
import
(
NeuronQuantConfig
)
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.tpu_int8
import
Int8TpuConfig
...
...
@@ -46,6 +48,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"bitsandbytes"
:
BitsAndBytesConfig
,
"qqq"
:
QQQConfig
,
"experts_int8"
:
ExpertsInt8Config
,
"neuron_quant"
:
NeuronQuantConfig
,
}
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
0640f227
...
...
@@ -218,5 +218,4 @@ class AWQLinearMethod(LinearMethodBase):
if
bias
is
not
None
:
out
.
add_
(
bias
)
return
out
.
reshape
(
out_shape
)
return
out
.
reshape
(
out_shape
)
\ No newline at end of file
vllm/model_executor/layers/quantization/awq_triton.py
0 → 100644
View file @
0640f227
import
torch
import
triton
import
triton.language
as
tl
AWQ_TRITON_SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
# Dequantize a tile of packed 4-bit AWQ weights.
#
# Each program handles BLOCK_SIZE_Y rows and BLOCK_SIZE_X packed columns of
# qweight (program id axis 0 -> columns, axis 1 -> rows) and writes
# BLOCK_SIZE_X * 8 output columns, i.e. eight 4-bit values per packed element.
@triton.jit
def awq_dequantize_kernel(
        qweight_ptr,  # quantized matrix
        scales_ptr,  # scales, per group
        zeros_ptr,  # zeros, per group
        group_size,  # Should always be one of the supported group sizes
        result_ptr,  # Output matrix
        num_cols,  # input num cols in qweight
        num_rows,  # input num rows in qweight
        BLOCK_SIZE_X: tl.constexpr,
        BLOCK_SIZE_Y: tl.constexpr):
    # Setup the pids.
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)

    # Compute offsets and masks for qweight_ptr.
    # The `// 8` repeats each packed column index eight times so the loaded
    # tile already has one entry per output column.
    offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
    offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]

    masks_y = offsets_y < num_rows
    masks_x = offsets_x < num_cols

    masks = masks_y[:, None] & masks_x[None, :]

    # Compute offsets and masks for result output ptr.
    # The result has 8 * num_cols columns, hence the row pitch below.
    result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(
        0, BLOCK_SIZE_X * 8)
    result_offsets = (8 * num_cols * result_offsets_y[:, None] +
                      result_offsets_x[None, :])

    result_masks_y = result_offsets_y < num_rows
    result_masks_x = result_offsets_x < num_cols * 8
    result_masks = result_masks_y[:, None] & result_masks_x[None, :]

    # Load the weights.
    iweights = tl.load(qweight_ptr + offsets, masks)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
                                tl.arange(0, 4)[:, None]).reshape(8)

    # Use this to compute a set of shifts that can be used to unpack and
    # reorder the values in iweights and zeros.
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :],
                             (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Unpack and reorder: shift out the correct 4-bit value and mask.
    iweights = (iweights >> shifts) & 0xF

    # Compute zero offsets and masks.
    # Rows are mapped to their quantization group by dividing by group_size;
    # zeros are packed like qweight (num_cols packed columns per row).
    zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
                      tl.arange(0, BLOCK_SIZE_Y) // group_size)
    zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
    zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]

    zero_masks_y = zero_offsets_y < num_rows // group_size
    zero_masks_x = zero_offsets_x < num_cols
    zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]

    # Load the zeros.
    zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)

    # Unpack and reorder: shift out the correct 4-bit value and mask.
    zeros = (zeros >> shifts) & 0xF

    # Compute scale offsets and masks.
    # Scales are addressed unpacked: num_cols * 8 columns per group row.
    scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size +
                       tl.arange(0, BLOCK_SIZE_Y) // group_size)
    scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 +
                       tl.arange(0, BLOCK_SIZE_X * 8))
    scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
                     scale_offsets_x[None, :])
    scale_masks_y = scale_offsets_y < num_rows // group_size
    scale_masks_x = scale_offsets_x < num_cols * 8
    scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]

    # Load the scales.
    scales = tl.load(scales_ptr + scale_offsets, scale_masks)

    # Dequantize.
    iweights = (iweights - zeros) * scales
    # Cast to the element type of the output buffer.
    iweights = iweights.to(result_ptr.type.element_ty)

    # Finally, store.
    tl.store(result_ptr + result_offsets, iweights, result_masks)
# Matmul of a_ptr (addressed as M x K) with 4-bit packed AWQ weights b_ptr
# (addressed as K x N // 8), dequantized on the fly, written to c_ptr (M x N).
#
# Program id axis 0 enumerates (pid_m, pid_n) output tiles; axis 1 (pid_z) is
# the split-K slice. Each program walks K in steps of BLOCK_SIZE_K * SPLIT_K.
# With SPLIT_K == 1 the tile is stored directly; otherwise partial results
# are combined into c_ptr with atomic adds.
@triton.jit
def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
                    group_size, BLOCK_SIZE_M: tl.constexpr,
                    BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
                    SPLIT_K: tl.constexpr):
    pid = tl.program_id(axis=0)
    pid_z = tl.program_id(1)

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode.  Use below instead.
    # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    # Decompose the flat program id into output tile coordinates.
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    # The accumulator uses the element type of the output buffer.
    accumulator_dtype = c_ptr.type.element_ty

    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode.  Use below instead.
    # accumulator = tl.arange(0, BLOCK_SIZE_N)
    # accumulator = tl.broadcast_to(accumulator[None, :],
    # (BLOCK_SIZE_M, BLOCK_SIZE_N))
    # accumulator = accumulator & 0x0
    # accumulator = accumulator.to(accumulator_dtype)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
                           dtype=accumulator_dtype)

    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
    # that will map given indices to the correct order.
    reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] +
                                tl.arange(0, 4)[:, None]).reshape(8)

    # Create the necessary shifts to use to unpack.
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :],
                             (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))

    # Offsets and masks.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    masks_am = offsets_am < M

    # Packed column indices for b; `// 8` repeats each index eight times.
    offsets_bn = (pid_n * (BLOCK_SIZE_N // 8) +
                  tl.arange(0, BLOCK_SIZE_N) // 8)
    masks_bn = offsets_bn < N // 8

    # Zeros use the same packed column indexing as b.
    offsets_zn = (pid_n * (BLOCK_SIZE_N // 8) +
                  tl.arange(0, BLOCK_SIZE_N) // 8)
    masks_zn = offsets_zn < N // 8

    # Scales are addressed per unpacked output column.
    offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    masks_sn = offsets_sn < N

    # Each split-K slice starts at its own BLOCK_SIZE_K offset along K.
    offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
    offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
    # block_offset = BLOCK_SIZE_K * SPLIT_K
    # for k in range(0, (K + block_offset - 1) // (block_offset)):
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b)

        # Dequantize b.
        # Group row index for the K rows covered in this iteration.
        offsets_szk = (
            (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size +
            tl.arange(0, BLOCK_SIZE_K) // group_size)
        offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
        masks_zk = offsets_szk < K // group_size
        masks_z = masks_zk[:, None] & masks_zn[None, :]
        zeros_ptrs = zeros_ptr + offsets_z
        zeros = tl.load(zeros_ptrs, mask=masks_z)

        offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
        masks_sk = offsets_szk < K // group_size
        masks_s = masks_sk[:, None] & masks_sn[None, :]
        scales_ptrs = scales_ptr + offsets_s
        scales = tl.load(scales_ptrs, mask=masks_s)

        # Unpack the 4-bit values, then apply (value - zero) * scale.
        b = (b >> shifts) & 0xF
        zeros = (zeros >> shifts) & 0xF
        b = (b - zeros) * scales
        b = b.to(c_ptr.type.element_ty)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        # Advance to this slice's next K block.
        offsets_k += BLOCK_SIZE_K * SPLIT_K
        a_ptrs += BLOCK_SIZE_K * SPLIT_K
        b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)

    c = accumulator.to(c_ptr.type.element_ty)
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    if SPLIT_K == 1:
        tl.store(c_ptrs, c, mask=c_mask)
    else:
        # Several split-K programs contribute to the same output tile.
        tl.atomic_add(c_ptrs, c, mask=c_mask)
def awq_dequantize_triton(qweight: torch.Tensor,
                          scales: torch.Tensor,
                          zeros: torch.Tensor,
                          block_size_x: int = 32,
                          block_size_y: int = 32) -> torch.Tensor:
    """Dequantize packed AWQ weights by launching ``awq_dequantize_kernel``.

    Expected layouts (G is the group size):
        qweights - [K     , M // 8], int32
        scales   - [K // G, M     ], float16
        zeros    - [K // G, M // 8], int32

    Returns:
        A new tensor with ``qweight.shape[0]`` rows and
        ``qweight.shape[1] * 8`` columns, allocated on ``qweight.device``
        with ``scales.dtype``.

    Raises:
        AssertionError: if the shapes are inconsistent or the inferred group
            size is not supported.
    """
    K = qweight.shape[0]
    M = scales.shape[1]
    # The group size is inferred from how many scale rows cover the K rows.
    group_size = K // scales.shape[0]

    assert K > 0 and M > 0
    assert scales.shape[0] == K // group_size
    assert scales.shape[1] == M
    assert zeros.shape[0] == K // group_size
    assert zeros.shape[1] == M // 8
    assert group_size <= K
    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

    num_rows = qweight.shape[0]
    num_cols = qweight.shape[1]

    # Result tensor:
    #   number of rows = same as input tensor
    #   number of cols = 8 x input tensor num cols
    result = torch.empty(num_rows,
                         num_cols * 8,
                         device=qweight.device,
                         dtype=scales.dtype)

    def grid(meta):
        # One program per (BLOCK_SIZE_X x BLOCK_SIZE_Y) tile of qweight.
        return (
            triton.cdiv(num_cols, meta['BLOCK_SIZE_X']),
            triton.cdiv(num_rows, meta['BLOCK_SIZE_Y']),
        )

    awq_dequantize_kernel[grid](qweight,
                                scales,
                                zeros,
                                group_size,
                                result,
                                num_cols,
                                num_rows,
                                BLOCK_SIZE_X=block_size_x,
                                BLOCK_SIZE_Y=block_size_y)

    return result
def awq_gemm_triton(input: torch.Tensor,
                    qweight: torch.Tensor,
                    scales: torch.Tensor,
                    qzeros: torch.Tensor,
                    split_k_iters: int,
                    block_size_m: int = 32,
                    block_size_n: int = 32,
                    block_size_k: int = 32) -> torch.Tensor:
    """Multiply ``input`` by packed AWQ weights via ``awq_gemm_kernel``.

    Expected layouts (G is the group size):
        input   - [M, K]
        qweight - [K, N // 8]
        qzeros  - [K // G, N // 8]
        scales  - [K // G, N]
        split_k_iters - parallelism along K-dimension, int, power of 2.

    Returns:
        A new ``[M, N]`` tensor with ``scales.dtype`` on ``input.device``.

    Raises:
        AssertionError: if the shapes are inconsistent, ``split_k_iters`` is
            not a power of two in [1, 32], or the inferred group size is not
            supported.
    """
    M, K = input.shape
    N = qweight.shape[1] * 8
    # The group size is inferred from how many zero-point rows cover K.
    group_size = qweight.shape[0] // qzeros.shape[0]

    assert N > 0 and K > 0 and M > 0
    assert qweight.shape[0] == K
    assert qweight.shape[1] == N // 8
    assert qzeros.shape[0] == K // group_size
    assert qzeros.shape[1] == N // 8
    assert scales.shape[0] == K // group_size
    assert scales.shape[1] == N
    # Power-of-two check: exactly one bit set, and non-zero.
    assert (split_k_iters & (split_k_iters - 1)) == 0 and split_k_iters != 0
    assert split_k_iters <= 32
    assert group_size <= K
    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K

    def grid(meta):
        # Axis 0 enumerates output tiles; axis 1 is the split-K slice.
        tiles_m = triton.cdiv(M, meta['BLOCK_SIZE_M'])
        tiles_n = triton.cdiv(N, meta['BLOCK_SIZE_N'])
        return (tiles_m * tiles_n, split_k_iters)

    # Zero-initialized because the kernel accumulates into it with atomic
    # adds whenever SPLIT_K > 1.
    result = torch.zeros((M, N), dtype=scales.dtype, device=input.device)

    # A = input, B = qweight, C = result
    # A = M x K, B = K x N, C = M x N
    awq_gemm_kernel[grid](input,
                          qweight,
                          result,
                          qzeros,
                          scales,
                          M,
                          N,
                          K,
                          group_size,
                          BLOCK_SIZE_M=block_size_m,
                          BLOCK_SIZE_N=block_size_n,
                          BLOCK_SIZE_K=block_size_k,
                          SPLIT_K=split_k_iters)

    return result
vllm/model_executor/layers/quantization/bitsandbytes.py
View file @
0640f227
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
...
...
@@ -15,8 +14,28 @@ class BitsAndBytesConfig(QuantizationConfig):
Reference: https://arxiv.org/abs/2305.14314
"""
def
__init__
(
self
,
)
->
None
:
pass
def
__init__
(
self
,
load_in_8bit
:
bool
=
False
,
load_in_4bit
:
bool
=
True
,
bnb_4bit_compute_dtype
:
str
=
"float32"
,
bnb_4bit_quant_type
:
str
=
"fp4"
,
bnb_4bit_use_double_quant
:
bool
=
False
,
llm_int8_enable_fp32_cpu_offload
:
bool
=
False
,
llm_int8_has_fp16_weight
:
bool
=
False
,
llm_int8_skip_modules
:
Optional
[
Any
]
=
None
,
llm_int8_threshold
:
float
=
0.0
,
)
->
None
:
self
.
load_in_8bit
=
load_in_8bit
self
.
load_in_4bit
=
load_in_4bit
self
.
bnb_4bit_compute_dtype
=
bnb_4bit_compute_dtype
self
.
bnb_4bit_quant_type
=
bnb_4bit_quant_type
self
.
bnb_4bit_use_double_quant
=
bnb_4bit_use_double_quant
self
.
llm_int8_enable_fp32_cpu_offload
=
llm_int8_enable_fp32_cpu_offload
self
.
llm_int8_has_fp16_weight
=
llm_int8_has_fp16_weight
self
.
llm_int8_skip_modules
=
llm_int8_skip_modules
self
.
llm_int8_threshold
=
llm_int8_threshold
def
__repr__
(
self
)
->
str
:
return
"BitsAndBytesConfig"
...
...
@@ -41,7 +60,46 @@ class BitsAndBytesConfig(QuantizationConfig):
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"BitsAndBytesConfig"
:
return
cls
()
def
get_safe_value
(
config
,
keys
,
default_value
=
None
):
try
:
value
=
cls
.
get_from_keys
(
config
,
keys
)
return
value
if
value
is
not
None
else
default_value
except
ValueError
:
return
default_value
load_in_8bit
=
get_safe_value
(
config
,
[
"load_in_8bit"
],
default_value
=
False
)
load_in_4bit
=
get_safe_value
(
config
,
[
"load_in_4bit"
],
default_value
=
True
)
bnb_4bit_compute_dtype
=
get_safe_value
(
config
,
[
"bnb_4bit_compute_dtype"
],
default_value
=
"float32"
)
bnb_4bit_quant_type
=
get_safe_value
(
config
,
[
"bnb_4bit_quant_type"
],
default_value
=
"fp4"
)
bnb_4bit_use_double_quant
=
get_safe_value
(
config
,
[
"bnb_4bit_use_double_quant"
],
default_value
=
False
)
llm_int8_enable_fp32_cpu_offload
=
get_safe_value
(
config
,
[
"llm_int8_enable_fp32_cpu_offload"
],
default_value
=
False
)
llm_int8_has_fp16_weight
=
get_safe_value
(
config
,
[
"llm_int8_has_fp16_weight"
],
default_value
=
False
)
llm_int8_skip_modules
=
get_safe_value
(
config
,
[
"llm_int8_skip_modules"
],
default_value
=
[])
llm_int8_threshold
=
get_safe_value
(
config
,
[
"llm_int8_threshold"
],
default_value
=
0.0
)
return
cls
(
load_in_8bit
=
load_in_8bit
,
load_in_4bit
=
load_in_4bit
,
bnb_4bit_compute_dtype
=
bnb_4bit_compute_dtype
,
bnb_4bit_quant_type
=
bnb_4bit_quant_type
,
bnb_4bit_use_double_quant
=
bnb_4bit_use_double_quant
,
llm_int8_enable_fp32_cpu_offload
=
llm_int8_enable_fp32_cpu_offload
,
llm_int8_has_fp16_weight
=
llm_int8_has_fp16_weight
,
llm_int8_skip_modules
=
llm_int8_skip_modules
,
llm_int8_threshold
=
llm_int8_threshold
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"BitsAndBytesLinearMethod"
]:
...
...
@@ -78,39 +136,58 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
quant_ratio
=
0
if
params_dtype
.
is_floating_point
:
quant_ratio
=
torch
.
finfo
(
params_dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
from
bitsandbytes.nn
import
Int8Params
def
calculate_quant_ratio
(
dtype
):
if
dtype
.
is_floating_point
:
return
torch
.
finfo
(
dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
else
:
return
torch
.
iinfo
(
dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
def
create_qweight_for_8bit
():
qweight
=
Int8Params
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
torch
.
int8
),
has_fp16_weights
=
self
.
quant_config
.
llm_int8_has_fp16_weight
,
requires_grad
=
False
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
0
,
"pack_factor"
:
1
,
"use_bitsandbytes_8bit"
:
True
,
"generation"
:
0
})
return
qweight
def
create_qweight_for_4bit
():
quant_ratio
=
calculate_quant_ratio
(
params_dtype
)
total_size
=
input_size_per_partition
*
sum
(
output_partition_sizes
)
if
total_size
%
quant_ratio
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
"weight shape."
)
qweight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
total_size
//
quant_ratio
,
1
,
dtype
=
torch
.
uint8
),
requires_grad
=
False
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
0
,
"pack_factor"
:
quant_ratio
,
"use_bitsandbytes_4bit"
:
True
})
return
qweight
if
self
.
quant_config
.
load_in_8bit
:
qweight
=
create_qweight_for_8bit
()
else
:
quant_ratio
=
torch
.
iinfo
(
params_dtype
).
bits
//
torch
.
iinfo
(
torch
.
uint8
).
bits
if
input_size_per_partition
*
sum
(
output_partition_sizes
)
%
quant_ratio
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
"weight shape. "
)
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
*
sum
(
output_partition_sizes
)
//
quant_ratio
,
1
,
dtype
=
torch
.
uint8
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
# In bitsandbytes, a tensor of shape [n,m] is quantized to
#[n*m/pack_ratio, 1],so the output_dim is 0
"output_dim"
:
0
,
"pack_factor"
:
quant_ratio
,
"use_bitsandbytes"
:
True
,
})
qweight
=
create_qweight_for_4bit
()
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
...
...
@@ -119,6 +196,88 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
quant_config
.
load_in_8bit
:
return
self
.
_apply_8bit_weight
(
layer
,
x
,
bias
)
else
:
return
self
.
_apply_4bit_weight
(
layer
,
x
,
bias
)
def _apply_8bit_weight(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the 8-bit bitsandbytes weight of ``layer`` to ``x``.

    The weight is processed shard by shard (shard boundaries come from
    ``qweight.bnb_shard_offsets``); each shard's result is written into
    a contiguous column slice of the output.

    Args:
        layer: module holding the ``qweight`` parameter together with its
            ``bnb_shard_offsets``, ``bnb_quant_state``, ``matmul_state``
            and ``generation`` attributes.
        x: input tensor; ``x.shape[0]`` becomes the first output dim.
        bias: optional bias added in place to the result.

    Returns:
        The result cast back to the dtype ``x`` had on entry.

    Side effects:
        Increments ``qweight.generation`` on every call, may (re)create
        entries of ``qweight.matmul_state`` and may overwrite slices of
        ``qweight`` with the state's ``CxB`` tensor.
    """

    # only load the bitsandbytes module when needed
    from bitsandbytes import MatmulLtState, matmul

    # Remember the caller's dtype; the matmul itself runs on a bfloat16
    # copy of the input.
    original_type = x.dtype
    bf_x = x.to(torch.bfloat16)

    qweight = layer.qweight
    offsets = qweight.bnb_shard_offsets
    quant_states = qweight.bnb_quant_state
    matmul_states = qweight.matmul_state
    generation = qweight.generation

    out_dim_0 = x.shape[0]
    # Total output width is the sum of every shard's quant-state length.
    out_dim_1 = sum(
        [quant_state[1].shape[0] for quant_state in quant_states.items()])
    # NOTE(review): the buffer is float16 while the matmul input is
    # bfloat16 -- confirm the implicit conversion on assignment is intended.
    out = torch.empty(out_dim_0,
                      out_dim_1,
                      dtype=torch.float16,
                      device=x.device)

    # Column cursor into `out` for the shard currently being processed.
    current_index = 0
    for i in range(len(quant_states)):
        output_size = quant_states[i].shape[0]

        # in profile_run or the first generation of inference,
        # create new matmul_states
        if generation == 0 or generation == 1:
            matmul_states[i] = MatmulLtState()
            matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]]
            matmul_states[i].SCB = quant_states[i]
            matmul_states[i].threshold = (
                self.quant_config.llm_int8_threshold)
            matmul_states[i].has_fp16_weights = (
                self.quant_config.llm_int8_has_fp16_weight)
            matmul_states[i].is_training = False
            if matmul_states[i].threshold > 0.0 and not matmul_states[
                    i].has_fp16_weights:
                matmul_states[i].use_pool = True

        # Add a leading dimension of size 1 before calling matmul.
        new_x = bf_x.unsqueeze(0)

        out[:, current_index:current_index + output_size] = matmul(
            new_x,
            qweight[offsets[i]:offsets[i + 1]],
            state=matmul_states[i])

        current_index += output_size

        # only update the matmul_states if it is not profile_run
        if (generation > 0
                and not self.quant_config.llm_int8_has_fp16_weight
                and matmul_states[i].CB is not None
                and matmul_states[i].CxB is not None):
            # Drop CB from the state and store CxB back into this shard's
            # rows of qweight.
            del matmul_states[i].CB
            qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB

    out = out.to(original_type)

    if bias is not None:
        out += bias

    # Advance the call counter that drives the generation checks above.
    qweight.generation += 1

    return out
def
_apply_4bit_weight
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# only load the bitsandbytes module when needed
from
bitsandbytes
import
matmul_4bit
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
0640f227
...
...
@@ -3,15 +3,18 @@ from typing import Any, Dict, List, Optional
import
torch
from
pydantic
import
BaseModel
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
# noqa: E501
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe
import
(
# noqa: E501
CompressedTensorsMoEMethod
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensorsScheme
,
CompressedTensorsUnquantized
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
CompressedTensorsScheme
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
,
find_matched_target
,
is_activation_quantization_format
,
...
...
@@ -52,18 +55,25 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_name(self) -> str:
    """Return the identifier string of this quantization config."""
    return "compressed_tensors"
# TODO (@robertgshaw2-neuralmagic): do layer skipping though here
# rather than though create_weights to match other methods
def get_quant_method(
    self,
    layer: torch.nn.Module,
    prefix: str,
) -> Optional["QuantizeMethodBase"]:
    """Select the quantize method object for ``layer``.

    Layers whose ``prefix`` matches the config's ignore list get an
    ``UnquantizedLinearMethod``. Otherwise the method is chosen by layer
    type: linear layers (which also get their scheme attached as
    ``layer.scheme``), attention layers and fused MoE layers. Any other
    layer type yields ``None``.
    """
    from vllm.attention.layer import Attention  # Avoid circular import

    # Check if the layer is skipped for quantization.
    # TODO (@robertgshaw2): support module names
    is_ignored = should_ignore_layer(prefix, ignore=self.ignore)
    if is_ignored:
        return UnquantizedLinearMethod()

    if isinstance(layer, LinearBase):
        # Resolve the scheme once here and stash it on the layer.
        layer.scheme = self.get_scheme(layer=layer, layer_name=prefix)
        return CompressedTensorsLinearMethod(self)
    elif isinstance(layer, Attention):
        return CompressedTensorsKVCacheMethod(self)
    elif isinstance(layer, FusedMoE):
        return CompressedTensorsMoEMethod(self)
    else:
        return None
@
classmethod
...
...
@@ -281,15 +291,11 @@ class CompressedTensorsConfig(QuantizationConfig):
to select the CompressedTensorsScheme used for inference.
"""
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if
should_ignore_layer
(
layer_name
,
ignore
=
self
.
ignore
):
return
CompressedTensorsUnquantized
()
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
matched_target
=
find_matched_target
(
layer_name
=
layer_name
,
module
=
layer
,
...
...
@@ -327,10 +333,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
details
"""
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer_name
=
extra_weight_attrs
.
get
(
"prefix"
)
scheme
=
self
.
quantization_config
.
get_scheme
(
layer
,
layer_name
)
scheme
.
create_weights
(
layer
.
scheme
.
create_weights
(
layer
=
layer
,
input_size
=
input_size
,
input_size_per_partition
=
input_size_per_partition
,
...
...
@@ -339,8 +342,6 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
params_dtype
=
params_dtype
,
weight_loader
=
weight_loader
)
layer
.
scheme
=
scheme
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
0 → 100644
View file @
0640f227
import
enum
from
enum
import
Enum
from
typing
import
Callable
,
List
,
Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
FusedMoEMethodBase
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
WNA16_SUPPORTED_BITS
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
)
from
vllm.model_executor.utils
import
set_weight_attrs
class GPTQMarlinState(Enum):
    """State marker stored on a layer as ``layer.marlin_state``."""

    # Initial state assigned to a layer when its weights are created.
    REPACK = enum.auto()
    # Presumably marks a layer whose weights were already repacked --
    # TODO confirm against the code that sets it.
    READY = enum.auto()
__all__
=
[
"CompressedTensorsMoEMethod"
]
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
    """Fused MoE method for compressed-tensors ``pack_quantized`` weights.

    The weight settings are read from the ``"Linear"`` entry of the
    config's ``target_scheme_map``. Weights are stored packed into int32,
    repacked with ``ops.gptq_marlin_moe_repack`` after loading, and
    executed through ``fused_marlin_moe``.
    """

    def __init__(
            self,
            quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
    ):
        """Read the weight quantization settings from ``quant_config``.

        Raises:
            AssertionError: if the weight quantization is not symmetric.
            ValueError: if the format is not ``pack_quantized`` or the bit
                width is not in ``WNA16_SUPPORTED_BITS``.
        """
        self.quant_config = quant_config
        # TODO: @dsikka: refactor this to use schemes as other kernels
        # are supported + check if the layer is being ignored.
        config = self.quant_config.target_scheme_map["Linear"].get("weights")
        self.num_bits = config.num_bits
        # Number of quantized values stored in one int32 element.
        self.packed_factor = 32 // config.num_bits
        self.strategy = config.strategy.value
        self.group_size = config.group_size
        assert config.symmetric, (
            "Only symmetric quantization is supported for MoE")

        if not (self.quant_config.quant_format
                == CompressionFormat.pack_quantized.value
                and self.num_bits in WNA16_SUPPORTED_BITS):
            # Build ONE message string. Passing the pieces as separate
            # comma-separated arguments would make the exception carry a
            # tuple of fragments instead of a readable message.
            raise ValueError(
                "For Fused MoE layers, only "
                f"{CompressionFormat.pack_quantized.value} "
                "is supported for the following bits: "
                f"{WNA16_SUPPORTED_BITS}")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Allocate and register every parameter the MoE layer needs.

        Registers packed weights, scales, weight-shape tensors, g_idx
        tensors and g_idx sort indices for both ``w13`` and ``w2``. Every
        parameter receives ``extra_weight_attrs`` extended with
        ``is_transposed`` and ``quant_method``. For the ``"channel"``
        strategy ``self.group_size`` is reset to -1.
        """

        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
        # shard for TP along the transposed dims
        extra_weight_attrs.update({
            "is_transposed": True,
            "quant_method": self.strategy
        })

        def _register(name: str, data: torch.Tensor) -> None:
            # Wrap as a frozen parameter, attach it to the layer, then tag
            # it with the loader attributes (same order for every tensor).
            param = torch.nn.Parameter(data, requires_grad=False)
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        _register(
            "w13_weight_packed",
            torch.empty(num_experts,
                        hidden_size // self.packed_factor,
                        2 * intermediate_size,
                        dtype=torch.int32))
        _register(
            "w2_weight_packed",
            torch.empty(num_experts,
                        intermediate_size // self.packed_factor,
                        hidden_size,
                        dtype=torch.int32))

        if self.strategy == "channel":
            # A single scale group per output channel.
            num_groups_w2 = num_groups_w13 = 1
            self.group_size = -1
        else:
            num_groups_w2 = intermediate_size // self.group_size
            num_groups_w13 = hidden_size // self.group_size

        _register(
            "w13_weight_scale",
            torch.ones(num_experts,
                       num_groups_w13,
                       2 * intermediate_size,
                       dtype=params_dtype))
        _register(
            "w2_weight_scale",
            torch.ones(num_experts,
                       num_groups_w2,
                       hidden_size,
                       dtype=params_dtype))

        _register("w2_weight_shape", torch.empty(num_experts, 2))
        _register("w13_weight_shape", torch.empty(num_experts, 2))

        _register(
            "w13_g_idx",
            torch.empty(num_experts, hidden_size, dtype=torch.int32))
        _register(
            "w2_g_idx",
            torch.empty(num_experts, intermediate_size, dtype=torch.int32))
        _register(
            "w13_g_idx_sort_indices",
            torch.empty(num_experts, hidden_size, dtype=torch.int32))
        _register(
            "w2_g_idx_sort_indices",
            torch.empty(num_experts, intermediate_size, dtype=torch.int32))

        layer.a13_scale = None
        layer.a2_scale = None
        layer.marlin_state = GPTQMarlinState.REPACK

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack the loaded weights and permute the scales in place.

        The g_idx tensors and their sort indices are replaced by empty
        ``(num_experts, 0)`` int32 parameters, the packed weights are run
        through ``ops.gptq_marlin_moe_repack`` and the scales are permuted
        per expert. Results are copied back into the existing parameters.
        """

        def replace_tensor(name, new_t):
            # It is important to use resize_() here since it ensures
            # the same buffer is reused
            getattr(layer, name).resize_(new_t.shape)
            getattr(layer, name).copy_(new_t)
            del new_t

        def get_scale_perms(num_bits: int):
            # Index permutations for grouped scales (64 entries) and for
            # the single-group/channel case (32 entries).
            scale_perm: List[int] = [
                i + 8 * j for i in range(8) for j in range(8)
            ]
            scale_perm_single: List[int] = [
                2 * i + j for i in range(4)
                for j in [0, 1, 8, 9, 16, 17, 24, 25]
            ]
            return scale_perm, scale_perm_single

        def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
                                  group_size: int, num_bits: int):
            # Permute one expert's scales and reshape to (-1, size_n).
            scale_perm, scale_perm_single = get_scale_perms(num_bits)
            if group_size < size_k and group_size != -1:
                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
            else:
                s = s.reshape((-1, len(scale_perm_single)))[:,
                                                            scale_perm_single]
            s = s.reshape((-1, size_n)).contiguous()
            return s

        def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
                                      size_n: int, group_size: int,
                                      num_bits: int):
            # Apply marlin_permute_scales to each expert along dim 0.
            num_experts = s.shape[0]
            output = torch.empty((num_experts, s.shape[1], s.shape[2]),
                                 device=s.device,
                                 dtype=s.dtype)
            for e in range(num_experts):
                output[e] = marlin_permute_scales(s[e], size_k, size_n,
                                                  group_size, num_bits)
            return output

        # Capture sizes before the packed weights are resized below.
        size_k2 = layer.w2_weight_packed.shape[2]
        size_k13 = layer.w13_weight_packed.shape[2]

        num_experts = layer.w13_g_idx.shape[0]
        device = layer.w13_g_idx.device

        # Replace all g_idx related parameters with empty placeholders.
        layer.w13_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
            layer.w13_weight_packed,
            layer.w13_g_idx_sort_indices,
            layer.w13_weight_packed.shape[1] * self.packed_factor,
            layer.w13_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w13_weight_packed", marlin_w13_qweight)
        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
            layer.w2_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w2_weight_packed", marlin_w2_qweight)

        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            layer.w13_weight_scale,
            size_k13,
            layer.w13_weight_scale.shape[2],
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w13_weight_scale", marlin_w13_scales)
        marlin_w2_scales = marlin_moe_permute_scales(
            layer.w2_weight_scale,
            layer.w2_weight_scale.shape[1] * self.packed_factor,
            size_k2,
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w2_weight_scale", marlin_w2_scales)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool = True,
        use_grouped_topk: bool = False,
        num_expert_group: Optional[int] = None,
        topk_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Run the MoE layer on ``x`` through ``fused_marlin_moe``.

        ``use_grouped_topk``, ``num_expert_group`` and ``topk_group`` are
        accepted for interface compatibility but are not forwarded.
        """

        # Imported lazily inside the method, as in the original code.
        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_marlin_moe)

        return fused_marlin_moe(x,
                                layer.w13_weight_packed,
                                layer.w2_weight_packed,
                                router_logits,
                                layer.w13_g_idx,
                                layer.w2_g_idx,
                                layer.w13_g_idx_sort_indices,
                                layer.w2_g_idx_sort_indices,
                                top_k,
                                custom_routing_function,
                                renormalize=renormalize,
                                w1_scale=layer.w13_weight_scale,
                                w2_scale=layer.w2_weight_scale)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
0640f227
from
.compressed_tensors_scheme
import
CompressedTensorsScheme
from
.compressed_tensors_unquantized
import
CompressedTensorsUnquantized
from
.compressed_tensors_w4a16_24
import
(
W4A16SPARSE24_SUPPORTED_BITS
,
CompressedTensorsW4A16Sparse24
)
from
.compressed_tensors_w8a8_fp8
import
CompressedTensorsW8A8Fp8
...
...
@@ -10,7 +9,6 @@ from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
__all__
=
[
"CompressedTensorsScheme"
,
"CompressedTensorsUnquantized"
,
"CompressedTensorsWNA16"
,
"CompressedTensorsW8A16Fp8"
,
"CompressedTensorsW4A16Sparse24"
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
deleted
100644 → 0
View file @
82f1ffdf
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch.nn.functional
as
F
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.parameter
import
ModelWeightParameter
__all__
=
[
"CompressedTensorsUnquantized"
]
class CompressedTensorsUnquantized(CompressedTensorsScheme):
    """
    Scheme used for layers that the CompressedTensors config ignores.

    No quantization is involved: the loaded weight and the input are
    combined with a plain linear transformation.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum capability value required by this scheme."""
        # volta and up
        min_capability = 70
        return min_capability

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap ``layer.weight`` as a plain ``torch.nn.Parameter``."""
        # required by torch.compile to be torch.nn.Parameter
        plain_weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)
        layer.weight = plain_weight

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Allocate the weight and register it on ``layer`` as "weight".

        The tensor has shape
        ``(sum(output_partition_sizes), input_size_per_partition)`` and
        dtype ``params_dtype``.
        """
        storage = torch.empty(sum(output_partition_sizes),
                              input_size_per_partition,
                              dtype=params_dtype)
        param = ModelWeightParameter(data=storage,
                                     input_dim=1,
                                     output_dim=0,
                                     weight_loader=weight_loader)
        layer.register_parameter("weight", param)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Return ``F.linear`` of ``x`` with the layer weight and ``bias``."""
        result = F.linear(x, layer.weight, bias)
        return result
vllm/model_executor/layers/quantization/experts_int8.py
View file @
0640f227
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
...
...
@@ -96,15 +96,18 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_scale"
,
w2_scale
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
...
...
@@ -114,7 +117,8 @@ class ExpertsInt8MoEMethod(FusedMoEMethodBase):
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
)
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
0640f227
...
...
@@ -15,8 +15,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_channel_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
apply_fp8_linear
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
)
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
...
...
@@ -85,6 +86,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
...
...
@@ -95,20 +97,21 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
layer
.
orig_dtype
=
params_dtype
# WEIGHT
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
requires_grad
=
False
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
**
extra_weight_attrs
,
})
# WEIGHT SCALE
weight_scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
**
extra_weight_attrs
)
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
(
(
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE UPPER BOUND
...
...
@@ -118,6 +121,11 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
layer
.
input_scale_ub
=
input_scale_ub
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# required by torch.compile
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
0640f227
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
from
torch.nn
import
Module
...
...
@@ -7,7 +7,8 @@ from torch.nn.parameter import Parameter
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
@@ -332,19 +333,16 @@ class Fp8MoEMethod(FusedMoEMethodBase):
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
# If loading fp8 checkpoint, pass the weight loaders.
# If loading an fp16 checkpoint, do not (we will quantize in
# process_weights_after_loading()
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
set_weight_attrs
(
w13_weight_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
set_weight_attrs
(
w2_weight_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
if
self
.
quant_config
.
activation_scheme
==
"static"
:
...
...
@@ -357,19 +355,14 @@ class Fp8MoEMethod(FusedMoEMethodBase):
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
set_weight_attrs
(
w13_input_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
set_weight_attrs
(
w13_input_scale
,
extra_weight_attrs
)
w2_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
set_weight_attrs
(
w2_input_scale
,
{
"is_fp8_scale"
:
True
,
**
extra_weight_attrs
})
set_weight_attrs
(
w2_input_scale
,
extra_weight_attrs
)
else
:
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
...
...
@@ -475,15 +468,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
requires_grad
=
False
)
return
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
...
...
@@ -494,7 +490,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
)
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
...
...
vllm/model_executor/layers/quantization/gptq.py
View file @
0640f227
...
...
@@ -11,7 +11,11 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
RowvLLMParameter
)
class
GPTQConfig
(
QuantizationConfig
):
...
...
@@ -108,6 +112,7 @@ class GPTQLinearMethod(LinearMethodBase):
**
extra_weight_attrs
,
):
del
output_size
# Unused.
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
if
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
...
...
@@ -138,73 +143,81 @@ class GPTQLinearMethod(LinearMethodBase):
scale_and_zero_size
=
input_size_per_partition
//
group_size
scale_and_zero_input_dim
=
0
qweight
=
Parameter
(
torch
.
empty
(
qweight
=
PackedvLLM
Parameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
g_idx
=
Parameter
(
torch
.
tensor
(
[
i
//
self
.
quant_config
.
group_size
for
i
in
range
(
input_size_per_partition
)
],
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs
(
g_idx
,
{
"input_dim"
:
0
,
"ignore_warning"
:
True
})
qzeros
=
Parameter
(
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
0
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
weight_loader
=
weight_loader
)
g_idx
=
RowvLLMParameter
(
data
=
torch
.
tensor
(
[
i
//
self
.
quant_config
.
group_size
for
i
in
range
(
input_size_per_partition
)
],
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
weight_loader
=
weight_loader
)
qzeros_args
=
{
"data"
:
torch
.
empty
(
scale_and_zero_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qzeros
,
{
"input_dim"
:
scale_and_zero_input_dim
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
scales
=
Parameter
(
"weight_loader"
:
weight_loader
}
weight_scale_args
=
{
"data"
:
torch
.
empty
(
scale_and_zero_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"input_dim"
:
scale_and_zero_input_dim
,
"output_dim"
:
1
,
})
"weight_loader"
:
weight_loader
}
if
scale_and_zero_input_dim
is
None
:
scales
=
ChannelQuantScaleParameter
(
output_dim
=
1
,
**
weight_scale_args
)
qzeros
=
PackedColumnParameter
(
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
**
qzeros_args
)
else
:
scales
=
GroupQuantScaleParameter
(
output_dim
=
1
,
input_dim
=
0
,
**
weight_scale_args
)
qzeros
=
PackedvLLMParameter
(
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
**
qzeros_args
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
set_weight_attrs
(
g_idx
,
extra_weight_attrs
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
exllama_state
=
exllama_state
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# for torch.compile
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
qzeros
=
Parameter
(
layer
.
qzeros
.
data
,
requires_grad
=
False
)
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
g_idx
=
Parameter
(
layer
.
g_idx
.
data
,
requires_grad
=
False
)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if
layer
.
exllama_state
==
ExllamaState
.
UNINITIALIZED
:
...
...
vllm/model_executor/layers/quantization/gptq_marlin_24.py
View file @
0640f227
...
...
@@ -8,7 +8,10 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
...
...
@@ -149,7 +152,7 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
**
extra_weight_attrs
,
):
del
output_size
# Unused.
weight_loader
=
extra_weight_attrs
[
"weight_loader"
]
if
params_dtype
!=
torch
.
float16
:
raise
ValueError
(
f
"The params dtype must be float16, but got
{
params_dtype
}
"
)
...
...
@@ -187,87 +190,80 @@ class GPTQMarlin24LinearMethod(LinearMethodBase):
"Each permutation group must reside on the same gpu"
)
# Quantized 4Bit weights packed into Int32.
qweight
=
Parameter
(
torch
.
empty
(
qweight
=
PackedvLLM
Parameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
tile_size
//
2
,
output_size_per_partition
*
self
.
quant_config
.
tile_size
//
self
.
quant_config
.
pack_factor
,
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
"marlin_tile_size"
:
self
.
quant_config
.
tile_size
,
},
)
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
marlin_tile_size
=
self
.
quant_config
.
tile_size
,
weight_loader
=
weight_loader
)
# Meta
meta
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
8
//
2
//
2
,
output_size_per_partition
*
2
,
device
=
"cuda"
,
dtype
=
torch
.
int16
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
meta
,
{
"input_dim"
:
0
,
"packed_dim"
:
1
,
"pack_factor"
:
1
,
"output_dim"
:
1
,
"marlin_tile_size"
:
2
,
},
)
meta
=
PackedvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
8
//
2
//
2
,
output_size_per_partition
*
2
,
device
=
"cuda"
,
dtype
=
torch
.
int16
,
),
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
1
,
marlin_tile_size
=
2
,
weight_loader
=
weight_loader
)
# Determine if channelwise or not
input_groups
=
(
1
if
self
.
quant_config
.
group_size
==
-
1
else
input_size_per_partition
//
self
.
quant_config
.
group_size
)
scales
=
Parameter
(
weight_scale_args
=
{
"data"
:
torch
.
empty
(
input_groups
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"input_dim"
:
None
if
input_groups
==
1
else
0
,
"output_dim"
:
1
,
},
)
"weight_loader"
:
weight_loader
}
if
input_groups
==
1
:
scales
=
ChannelQuantScaleParameter
(
output_dim
=
1
,
**
weight_scale_args
)
else
:
scales
=
GroupQuantScaleParameter
(
output_dim
=
1
,
input_dim
=
0
,
**
weight_scale_args
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size
=
(
output_size_per_partition
//
self
.
quant_config
.
min_n_threads
)
*
self
.
quant_config
.
max_parallel
workspace
=
Parameter
(
torch
.
zeros
(
max_workspace_size
,
device
=
"cuda"
,
dtype
=
torch
.
int
),
requires_grad
=
False
)
workspace
=
BasevLLMParameter
(
data
=
torch
.
zeros
(
max_workspace_size
,
device
=
"cuda"
,
dtype
=
torch
.
int
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"B_24"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"B_meta"
,
meta
)
set_weight_attrs
(
meta
,
extra_weight_attrs
)
layer
.
register_parameter
(
"s"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
register_parameter
(
"workspace"
,
workspace
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Re-register the Marlin 2:4 tensors of ``layer`` as fresh Parameters.

    Each of ``B_24``, ``s``, ``B_meta`` and ``workspace`` is replaced by a
    new ``Parameter`` wrapping the same ``.data`` with gradients disabled.

    Args:
        layer: Module whose four attributes above have already been loaded.
    """
    # required by torch.compile
    for attr_name in ("B_24", "s", "B_meta", "workspace"):
        loaded = getattr(layer, attr_name)
        setattr(layer, attr_name, Parameter(loaded.data, requires_grad=False))
def
apply
(
self
,
...
...
vllm/model_executor/layers/quantization/neuron_quant.py
0 → 100644
View file @
0640f227
import
os
from
importlib.util
import
find_spec
from
typing
import
Any
,
Dict
,
List
,
Optional
from
torch.nn
import
Module
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
SUPPORTED_QUANT_DTYPE_LIST
=
[
's8'
,
'f8e4m3fn'
]
class NeuronQuantConfig(QuantizationConfig):
    """Quantization config for the Neuron backend.

    The quantized storage dtype is not a constructor argument: it is read
    from the ``NEURON_QUANT_DTYPE`` environment variable (default ``"s8"``)
    and must be one of ``SUPPORTED_QUANT_DTYPE_LIST``.
    """

    def __init__(
        self,
        dequant_dtype: str = "f16",
        quantize_method: str = "vector_dynamic",
    ) -> None:
        """Store the settings and validate the env-selected quant dtype.

        Args:
            dequant_dtype: Forwarded unchanged as ``dequant_dtype`` to the
                transformers_neuronx ``QuantizationConfig``.
            quantize_method: Forwarded unchanged as ``quantize_method`` to
                the transformers_neuronx ``QuantizationConfig``.

        Raises:
            ValueError: If ``NEURON_QUANT_DTYPE`` names an unsupported dtype.
        """
        self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
        if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
            # Adjacent literals are concatenated verbatim, so each fragment
            # must carry its own separating whitespace.
            raise ValueError(
                f"Neuron quantization datatype {self.quant_dtype} is not "
                "valid, the quantization datatype should match one of the "
                f"below types: {SUPPORTED_QUANT_DTYPE_LIST}")
        self.dequant_dtype = dequant_dtype
        self.quantize_method = quantize_method

    def get_name(self) -> str:
        """Return the identifier of this quantization method."""
        return "neuron_quant"

    def get_supported_act_dtypes(self) -> List[str]:
        """Return the supported dtype names (same values as the module list)."""
        # Hand out a copy so callers cannot mutate the module-level constant.
        return list(SUPPORTED_QUANT_DTYPE_LIST)

    @classmethod
    def get_min_capability(cls) -> int:
        """Always raise: this query is not meaningful for the Neuron backend.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError(
            "This function should not be called with Neuron Backend")

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return an empty list of config filenames."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
        """Build a config from the ``quantize_method``/``dequant_dtype`` keys.

        Both keys are looked up through ``cls.get_from_keys``.
        """
        quantize_method = cls.get_from_keys(config, ["quantize_method"])
        dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
        return cls(dequant_dtype=dequant_dtype,
                   quantize_method=quantize_method)

    def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
        """Return the transformers_neuronx quantization config.

        ``layer`` and ``prefix`` are accepted for interface compatibility
        and are not inspected.

        Raises:
            NotImplementedError: If ``transformers_neuronx`` is not
                importable.
        """
        if find_spec("transformers_neuronx") is None:
            raise NotImplementedError(
                "Neuron Quantization is only supported through"
                " transformers_neuronx.")
        return self.get_quantization_config()

    def get_scaled_act_names(self) -> List[str]:
        """Return an empty list of scaled activation names."""
        return []

    def get_quantization_config(self):
        """Create the transformers_neuronx ``QuantizationConfig`` object."""
        # Imported lazily so this module can be imported without
        # transformers_neuronx installed.
        from transformers_neuronx.config import QuantizationConfig
        return QuantizationConfig(quant_dtype=self.quant_dtype,
                                  dequant_dtype=self.dequant_dtype,
                                  quantize_method=self.quantize_method)
vllm/model_executor/layers/quantization/qqq.py
View file @
0640f227
...
...
@@ -8,7 +8,10 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedvLLMParameter
)
logger
=
init_logger
(
__name__
)
...
...
@@ -133,6 +136,7 @@ class QQQLinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
weight_loader
=
extra_weight_attrs
[
"weight_loader"
]
if
params_dtype
!=
torch
.
float16
:
raise
ValueError
(
f
"The params dtype must be float16, but got
{
params_dtype
}
"
)
...
...
@@ -170,90 +174,74 @@ class QQQLinearMethod(LinearMethodBase):
"Each permutation group must reside on the same gpu"
)
# Quantized 4Bit weights packed into Int32.
qweight
=
Parameter
(
torch
.
empty
(
qweight
=
PackedvLLM
Parameter
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
tile_size
,
output_size_per_partition
*
self
.
quant_config
.
tile_size
//
self
.
quant_config
.
pack_factor
,
device
=
"cuda"
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
"marlin_tile_size"
:
self
.
quant_config
.
tile_size
,
},
)
s_channel
=
Parameter
(
torch
.
empty
(
1
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
torch
.
float
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
s_channel
,
{
"input_dim"
:
None
,
"output_dim"
:
1
,
},
)
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
marlin_tile_size
=
self
.
quant_config
.
tile_size
,
weight_loader
=
weight_loader
)
s_channel
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
(
1
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
torch
.
float
,
),
weight_loader
=
weight_loader
,
output_dim
=
1
)
if
self
.
quant_config
.
group_size
==
-
1
:
s_group
=
Parameter
(
torch
.
tensor
(
[],
device
=
"cuda"
,
dtype
=
torch
.
half
,
),
requires_grad
=
False
,
s_group_data
=
torch
.
tensor
(
[],
device
=
"cuda"
,
dtype
=
torch
.
half
,
)
else
:
s_group
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
torch
.
half
,
),
requires_grad
=
False
,
s_group_data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
group_size
,
output_size_per_partition
,
device
=
"cuda"
,
dtype
=
torch
.
half
,
)
s
et_weight_attrs
(
s_group
,
{
"input_dim"
:
None
if
self
.
quant_config
.
group_size
==
-
1
else
0
,
"output_dim"
:
None
if
self
.
quant_config
.
group_size
==
-
1
else
1
,
}
,
)
s
_group_attr
=
{
"data"
:
s_group_data
,
"weight_loader"
:
weight_loader
}
if
self
.
quant_config
.
group_size
==
-
1
:
s_group
=
BasevLLMParameter
(
**
s_group_attr
)
else
:
s_group
=
GroupQuantScaleParameter
(
output_dim
=
1
,
input_dim
=
0
,
**
s_group_attr
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size
=
(
output_size_per_partition
//
self
.
quant_config
.
min_n_threads
)
*
self
.
quant_config
.
max_parallel
workspace
=
Parameter
(
torch
.
zeros
(
max_workspace_size
,
device
=
"cuda"
,
dtype
=
torch
.
int
),
requires_grad
=
False
)
workspace
=
BasevLLMParameter
(
data
=
torch
.
zeros
(
max_workspace_size
,
device
=
"cuda"
,
dtype
=
torch
.
int
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"B"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"s_channel"
,
s_channel
)
set_weight_attrs
(
s_channel
,
extra_weight_attrs
)
layer
.
register_parameter
(
"s_group"
,
s_group
)
set_weight_attrs
(
s_group
,
extra_weight_attrs
)
layer
.
register_parameter
(
"workspace"
,
workspace
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Re-register the QQQ tensors of ``layer`` as fresh Parameters.

    Each of ``B``, ``s_channel``, ``s_group`` and ``workspace`` is replaced
    by a new ``Parameter`` wrapping the same ``.data`` with gradients
    disabled.

    Args:
        layer: Module whose four attributes above have already been loaded.
    """
    # required by torch.compile
    for attr_name in ("B", "s_channel", "s_group", "workspace"):
        loaded = getattr(layer, attr_name)
        setattr(layer, attr_name, Parameter(loaded.data, requires_grad=False))
def
apply
(
self
,
...
...
vllm/model_executor/layers/quantization/tpu_int8.py
View file @
0640f227
...
...
@@ -7,7 +7,7 @@ from torch.nn.parameter import Parameter
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.
utils
import
set_weight_attrs
from
vllm.model_executor.
parameter
import
ModelWeightParameter
ACTIVATION_SCHEMES
=
[
"none"
]
...
...
@@ -64,16 +64,16 @@ class TPUInt8LinearMethod(LinearMethodBase):
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
**
extra_weight_attrs
,
"input_dim"
:
1
,
"output_dim"
:
0
,
})
def
_quantize_weight
(
self
,
weight
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
@@ -92,6 +92,7 @@ class TPUInt8LinearMethod(LinearMethodBase):
return
qweight
,
qscale
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
layer
.
weight
=
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
device
=
layer
.
weight
.
device
qweight
,
qscale
=
self
.
_quantize_weight
(
layer
.
weight
)
qweight
=
qweight
.
to
(
device
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
0640f227
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
...
...
@@ -38,31 +36,6 @@ def all_close_1d(x: torch.Tensor) -> bool:
return
all
(
torch
.
allclose
(
x
[
0
],
x
[
i
])
for
i
in
range
(
x
.
shape
[
0
]))
def create_per_tensor_scale_param(
    output_partition_sizes: List[int],
    **extra_weight_attrs,
) -> Parameter:
    """Create a float32 scale Parameter with one entry per output partition.

    Every entry starts at the most negative finite float32 value. The
    parameter is tagged with ``needs_scalar_to_array=True`` plus any
    ``extra_weight_attrs``; a caller-supplied key of the same name wins.

    Args:
        output_partition_sizes: Only its length is used.
        **extra_weight_attrs: Extra attributes passed to
            ``set_weight_attrs``.

    Returns:
        The non-trainable scale Parameter.
    """
    sentinel = torch.finfo(torch.float32).min
    initial = torch.full((len(output_partition_sizes), ),
                         sentinel,
                         dtype=torch.float32)
    scale = Parameter(initial, requires_grad=False)

    attrs = {"needs_scalar_to_array": True}
    attrs.update(extra_weight_attrs)
    set_weight_attrs(scale, attrs)
    return scale
def create_per_channel_scale_param(output_partition_sizes: List[int],
                                   **extra_weight_attrs) -> Parameter:
    """Create a float32 column of scales, one row per output channel.

    The result has shape ``(sum(output_partition_sizes), 1)`` and every
    entry starts at the most negative finite float32 value. The parameter
    is tagged with ``output_dim=0`` plus any ``extra_weight_attrs``; a
    caller-supplied key of the same name wins.

    Args:
        output_partition_sizes: Partition sizes; their sum is the row count.
        **extra_weight_attrs: Extra attributes passed to
            ``set_weight_attrs``.

    Returns:
        The non-trainable scale Parameter.
    """
    sentinel = torch.finfo(torch.float32).min
    num_channels = sum(output_partition_sizes)
    initial = torch.full((num_channels, 1), sentinel, dtype=torch.float32)
    scale = Parameter(initial, requires_grad=False)

    attrs = {"output_dim": 0}
    attrs.update(extra_weight_attrs)
    set_weight_attrs(scale, attrs)
    return scale
def
convert_to_channelwise
(
weight_scale
:
torch
.
Tensor
,
logical_widths
:
List
[
int
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
vllm/model_executor/layers/rejection_sampler.py
View file @
0640f227
from
functools
import
cached_property
from
importlib.util
import
find_spec
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.jit
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.spec_decode_base_sampler
import
(
SpecDecodeStochasticBaseSampler
)
logger
=
init_logger
(
__name__
)
# Consider utilizing the FlashInfer rejection sampling kernel initially,
# as it employs a dedicated kernel rather than relying on Torch tensor
# operations. This design choice helps to fuse operations, reduce memory
# I/O, and consequently enhances performance.
if not find_spec("flashinfer"):
    # FlashInfer is not installed: leave a None sentinel for later checks.
    chain_speculative_sampling = None
else:
    from flashinfer.sampling import chain_speculative_sampling
class
RejectionSampler
(
SpecDecodeStochasticBaseSampler
):
"""Apply modified rejection sampling as described in "Accelerating Large
...
...
@@ -16,7 +32,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
def
__init__
(
self
,
disable_bonus_tokens
:
bool
=
True
,
strict_mode
:
bool
=
False
):
strict_mode
:
bool
=
False
,
use_flashinfer
:
Optional
[
bool
]
=
None
):
"""Create a rejection sampler.
Args:
...
...
@@ -26,13 +43,29 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
use_flashinfer: We will use this parameter to determine whether
to use the FlashInfer rejection sampling kernel or not. If it's
None, we will use the default value from the environment variable.
This parameter is only used for testing purposes.
"""
super
().
__init__
(
disable_bonus_tokens
=
disable_bonus_tokens
,
strict_mode
=
strict_mode
)
if
use_flashinfer
is
None
:
self
.
use_flashinfer
=
envs
.
VLLM_USE_FLASHINFER_SAMPLER
and
(
chain_speculative_sampling
is
not
None
)
else
:
self
.
use_flashinfer
=
use_flashinfer
if
self
.
use_flashinfer
:
assert
not
disable_bonus_tokens
,
\
"flashinfer will enable bonus token by default"
logger
.
info
(
"Use flashinfer for rejection sampling."
)
else
:
logger
.
info
(
"Use pytorch for rejection sampling."
)
def
forward
(
self
,
target_probs
:
torch
.
Tensor
,
target_
with_bonus_
probs
:
torch
.
Tensor
,
bonus_token_ids
:
torch
.
Tensor
,
draft_probs
:
torch
.
Tensor
,
draft_token_ids
:
torch
.
Tensor
,
...
...
@@ -50,9 +83,9 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
sequence.
Args:
target_probs: The probability distribution
over token ids given
context according to the target model.
shape = [batch_size, num_speculative_tokens, vocab_size]
target_
with_bonus_
probs: The probability distribution
over token ids given
context according to the target model.
shape = [batch_size, num_speculative_tokens
+ 1
, vocab_size]
bonus_token_ids: The "bonus" token ids that are accepted iff all
speculative tokens in a sequence are accepted.
...
...
@@ -78,23 +111,52 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if
self
.
_strict_mode
:
self
.
_raise_if_incorrect_input
(
target_probs
,
draft_token_ids
,
bonus_token_ids
,
draft_probs
)
self
.
_raise_if_incorrect_input
(
target_with_bonus_probs
,
draft_token_ids
,
bonus_token_ids
,
draft_probs
)
accepted
,
recovered_token_ids
=
(
self
.
_batch_modified_rejection_sampling
(
target_probs
,
draft_probs
,
draft_token_ids
,
seeded_seqs
,
))
batch_size
,
k
,
_
=
draft_probs
.
shape
output_token_ids
=
self
.
_create_output
(
accepted
,
recovered_token_ids
,
draft_token_ids
,
bonus_token_ids
,
)
# batch_size = 0 when all requests in the batch are
# non_spec requests. In this case, output_token_ids is
# just an empty tensor.
if
batch_size
==
0
:
return
torch
.
empty
(
0
,
k
+
1
,
device
=
draft_probs
.
device
,
dtype
=
int
)
# If use Flashinfer chain_speculative_sampling kernel
# for rejection sampling
if
self
.
use_flashinfer
:
batch_size
,
k
,
_
=
draft_probs
.
shape
uniform_samples
=
self
.
_create_uniform_samples
(
seeded_seqs
,
batch_size
,
k
,
draft_probs
.
device
)
output_token_ids
,
accepted_token_num
,
emitted_token_num
\
=
chain_speculative_sampling
(
draft_probs
,
draft_token_ids
,
uniform_samples
,
target_with_bonus_probs
)
# num_emitted_tokens returned by flashinfer
# does not include the bonus token
# Flashinfer stops at the first token that violates
# the condition p >= q and does not include recovery/bonus token.
# Therefore, we need to add batch_size here.
self
.
num_accepted_tokens
+=
accepted_token_num
.
sum
()
self
.
num_emitted_tokens
+=
emitted_token_num
.
sum
()
+
batch_size
self
.
num_draft_tokens
+=
batch_size
*
k
else
:
accepted
,
recovered_token_ids
=
(
self
.
_batch_modified_rejection_sampling
(
target_with_bonus_probs
[:,
:
-
1
],
draft_probs
,
draft_token_ids
,
seeded_seqs
,
))
output_token_ids
=
self
.
_create_output
(
accepted
,
recovered_token_ids
,
draft_token_ids
,
bonus_token_ids
,
)
return
output_token_ids
...
...
@@ -135,6 +197,63 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
return
accepted
,
recovered_token_ids
def
_create_uniform_samples
(
self
,
seeded_seqs
:
Optional
[
Dict
[
int
,
torch
.
Generator
]],
batch_size
:
int
,
k
:
int
,
device
:
torch
.
device
)
->
torch
.
Tensor
:
"""
Generates a batch of uniform random samples, with optional seeding
for specific sequences.
This method creates a tensor of shape `(batch_size, k + 1)` filled
with uniform random values in the range [0, 1). If `seeded_seqs`
is provided, the sequences corresponding to specific indices
will be generated using the provided `torch.Generator` for
reproducibility. The other sequences will be generated without
a seed.
Args:
seeded_seqs : Optional[Dict[int, torch.Generator]]
A dictionary mapping indices in the batch to
`torch.Generator` objects. If `None`, all samples are
generated without a seed.
batch_size : int
The number of sequences to generate.
k : int
The number of random samples per sequence.
device : torch.device
The device on which to allocate the tensor.
Returns:
uniform_rand : torch.Tensor
A tensor of shape `(batch_size, k + 1)` containing uniform
random values in the range [0, 1).
"""
if
not
seeded_seqs
:
return
torch
.
rand
(
batch_size
,
k
+
1
,
device
=
device
)
uniform_rand
=
torch
.
empty
(
batch_size
,
k
+
1
,
device
=
device
)
non_seeded_indices
=
[]
for
idx
in
range
(
batch_size
):
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
append
(
idx
)
else
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
k
+
1
,
dtype
=
self
.
probs_dtype
,
device
=
device
,
generator
=
generator
)
if
non_seeded_indices
:
uniform_rand
[
non_seeded_indices
,
:]
=
torch
.
rand
(
len
(
non_seeded_indices
),
k
+
1
,
dtype
=
self
.
probs_dtype
,
device
=
device
)
return
uniform_rand
def
_get_accepted
(
self
,
target_probs
:
torch
.
Tensor
,
# [batch_size, k, vocab_size]
...
...
@@ -175,29 +294,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
selected_target_probs
=
target_probs
[
batch_indices
,
probs_indicies
,
draft_token_ids
]
if
not
seeded_seqs
:
uniform_rand
=
torch
.
rand_like
(
selected_target_probs
)
else
:
uniform_rand
=
torch
.
empty_like
(
selected_target_probs
)
non_seeded_indices
=
[]
for
idx
in
range
(
batch_size
):
generator
=
seeded_seqs
.
get
(
idx
)
if
generator
is
None
:
non_seeded_indices
.
append
(
idx
)
else
:
uniform_rand
[
idx
,
:]
=
torch
.
rand
(
1
,
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
,
generator
=
generator
)
if
non_seeded_indices
:
uniform_rand
[
non_seeded_indices
,
:]
=
torch
.
rand
(
len
(
non_seeded_indices
),
k
,
dtype
=
self
.
probs_dtype
,
device
=
target_probs
.
device
)
uniform_rand
=
self
.
_create_uniform_samples
(
seeded_seqs
,
batch_size
,
k
-
1
,
target_probs
.
device
)
capped_ratio
=
torch
.
minimum
(
selected_target_probs
/
selected_draft_probs
,
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment