Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4bc4de1
Unverified
Commit
f4bc4de1
authored
Apr 25, 2024
by
Kunshang Ji
Committed by
GitHub
Apr 25, 2024
Browse files
[Core]refactor aqlm quant ops (#4351)
parent
bd7a8eef
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
2 deletions
+16
-2
benchmarks/kernels/benchmark_aqlm.py
benchmarks/kernels/benchmark_aqlm.py
+1
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+14
-0
vllm/model_executor/layers/quantization/aqlm.py
vllm/model_executor/layers/quantization/aqlm.py
+1
-1
No files found.
benchmarks/kernels/benchmark_aqlm.py
View file @
f4bc4de1
...
@@ -6,7 +6,7 @@ from typing import Optional
...
@@ -6,7 +6,7 @@ from typing import Optional
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.aqlm
import
(
from
vllm.model_executor.layers.quantization.aqlm
import
(
dequantize_weight
,
generic_dequantize_gemm
,
get_int_dtype
,
dequantize_weight
,
generic_dequantize_gemm
,
get_int_dtype
,
optimized_dequantize_gemm
)
optimized_dequantize_gemm
)
...
...
vllm/_custom_ops.py
View file @
f4bc4de1
...
@@ -153,6 +153,20 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -153,6 +153,20 @@ def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_n
,
size_k
)
size_n
,
size_k
)
# aqlm
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
return
vllm_ops
.
aqlm_gemm
(
input
,
codes
,
codebooks
,
scales
,
codebook_partition_sizes
,
bias
)
def
aqlm_dequant
(
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
codebook_partition_sizes
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
vllm_ops
.
aqlm_dequant
(
codes
,
codebooks
,
codebook_partition_sizes
)
# fp8
# fp8
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
...
...
vllm/model_executor/layers/quantization/aqlm.py
View file @
f4bc4de1
...
@@ -8,7 +8,7 @@ import torch
...
@@ -8,7 +8,7 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm
._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
set_weight_attrs
)
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment