Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
23f32229
Unverified
Commit
23f32229
authored
Sep 06, 2024
by
Dipika Sikka
Committed by
GitHub
Sep 06, 2024
Browse files
[Misc] Remove `SqueezeLLM` (#8220)
parent
9db52eab
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
6 additions
and
389 deletions
+6
-389
CMakeLists.txt
CMakeLists.txt
+0
-1
csrc/ops.h
csrc/ops.h
+0
-3
csrc/quantization/squeezellm/quant_cuda_kernel.cu
csrc/quantization/squeezellm/quant_cuda_kernel.cu
+0
-216
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-6
docs/source/quantization/supported_hardware.rst
docs/source/quantization/supported_hardware.rst
+0
-11
examples/fp8/README.md
examples/fp8/README.md
+2
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-6
vllm/config.py
vllm/config.py
+2
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+1
-1
vllm/lora/layers.py
vllm/lora/layers.py
+1
-1
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+0
-2
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+0
-138
No files found.
CMakeLists.txt
View file @
23f32229
...
...
@@ -181,7 +181,6 @@ set(VLLM_EXT_SRC
"csrc/pos_encoding_kernels.cu"
"csrc/activation_kernels.cu"
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
...
...
csrc/ops.h
View file @
23f32229
...
...
@@ -170,9 +170,6 @@ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
void
dynamic_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
scales
);
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
);
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
...
...
csrc/quantization/squeezellm/quant_cuda_kernel.cu
deleted
100644 → 0
View file @
9db52eab
#include <torch/all.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>
#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
namespace
vllm
{
namespace
squeezellm
{
__device__
inline
unsigned
int
as_unsigned
(
int
i
)
{
return
*
reinterpret_cast
<
unsigned
int
*>
(
&
i
);
}
// 4-bit matvec kernel (LUT-based)
__global__
void
NUQ4MatMulKernel
(
#ifndef USE_ROCM
const
half2
*
__restrict__
vec
,
#else
const
__half2
*
__restrict__
vec
,
#endif
const
int
*
__restrict__
mat
,
#ifndef USE_ROCM
half2
*
__restrict__
mul
,
#else
float2
*
__restrict__
mul
,
#endif
const
__half
*
__restrict__
lookup_table
,
int
height
,
int
width
,
int
batch
,
int
vec_height
)
{
const
int
blockwidth2
=
BLOCKWIDTH
/
2
;
int
row
=
BLOCKHEIGHT4
*
blockIdx
.
x
;
int
col
=
BLOCKWIDTH
*
blockIdx
.
y
+
threadIdx
.
x
;
#ifndef USE_ROCM
__shared__
half2
blockvec
[
blockwidth2
];
#else
__shared__
__half2
blockvec
[
blockwidth2
];
#endif
__shared__
__half
deq2
[
16
][
BLOCKWIDTH
];
int
off
=
threadIdx
.
x
;
int
column_offset
=
col
*
16
;
for
(
int
val
=
0
;
val
<
16
;
val
+=
1
)
{
int
lut_index
=
column_offset
+
val
;
deq2
[
val
][
off
]
=
lookup_table
[
lut_index
];
}
__half
res
;
#ifndef USE_ROCM
half2
res2
;
half2
tmp2
;
#else
__half2
res2
;
__half2
tmp2
;
#endif
int
i
;
int
k
;
unsigned
int
tmp1
;
unsigned
int
lut_index1
,
lut_index2
;
for
(
int
b
=
0
;
b
<
batch
;
++
b
)
{
i
=
width
*
row
+
col
;
res
=
__int2half_rd
(
0
);
k
=
0
;
__syncthreads
();
if
(
threadIdx
.
x
<
blockwidth2
)
blockvec
[
threadIdx
.
x
]
=
vec
[
b
*
vec_height
/
2
+
(
row
/
BLOCKHEIGHT4
)
*
blockwidth2
+
threadIdx
.
x
];
__syncthreads
();
while
(
k
<
blockwidth2
)
{
tmp1
=
as_unsigned
(
mat
[
i
]);
#ifndef USE_ROCM
res2
=
{};
tmp2
=
{};
#else
res2
.
x
=
__half_as_ushort
(
__float2half
(
0
));
res2
.
y
=
__half_as_ushort
(
__float2half
(
0
));
tmp2
.
x
=
__half_as_ushort
(
__float2half
(
0
));
tmp2
.
y
=
__half_as_ushort
(
__float2half
(
0
));
#endif
lut_index1
=
tmp1
&
0xF
;
lut_index2
=
(
tmp1
>>
4
)
&
0xF
;
#ifndef USE_ROCM
tmp2
.
x
=
deq2
[
lut_index1
][
off
];
tmp2
.
y
=
deq2
[
lut_index2
][
off
];
#else
tmp2
.
x
=
__half_as_ushort
(
deq2
[
lut_index1
][
off
]);
tmp2
.
y
=
__half_as_ushort
(
deq2
[
lut_index2
][
off
]);
#endif
res2
=
__hfma2
(
tmp2
,
blockvec
[
k
+
0
],
res2
);
lut_index1
=
(
tmp1
>>
8
)
&
0xF
;
lut_index2
=
(
tmp1
>>
12
)
&
0xF
;
#ifndef USE_ROCM
tmp2
.
x
=
deq2
[
lut_index1
][
off
];
tmp2
.
y
=
deq2
[
lut_index2
][
off
];
#else
tmp2
.
x
=
__half_as_ushort
(
deq2
[
lut_index1
][
off
]);
tmp2
.
y
=
__half_as_ushort
(
deq2
[
lut_index2
][
off
]);
#endif
res2
=
__hfma2
(
tmp2
,
blockvec
[
k
+
1
],
res2
);
lut_index1
=
(
tmp1
>>
16
)
&
0xF
;
lut_index2
=
(
tmp1
>>
20
)
&
0xF
;
#ifndef USE_ROCM
tmp2
.
x
=
deq2
[
lut_index1
][
off
];
tmp2
.
y
=
deq2
[
lut_index2
][
off
];
#else
tmp2
.
x
=
__half_as_ushort
(
deq2
[
lut_index1
][
off
]);
tmp2
.
y
=
__half_as_ushort
(
deq2
[
lut_index2
][
off
]);
#endif
res2
=
__hfma2
(
tmp2
,
blockvec
[
k
+
2
],
res2
);
lut_index1
=
(
tmp1
>>
24
)
&
0xF
;
lut_index2
=
(
tmp1
>>
28
)
&
0xF
;
#ifndef USE_ROCM
tmp2
.
x
=
deq2
[
lut_index1
][
off
];
tmp2
.
y
=
deq2
[
lut_index2
][
off
];
#else
tmp2
.
x
=
__half_as_ushort
(
deq2
[
lut_index1
][
off
]);
tmp2
.
y
=
__half_as_ushort
(
deq2
[
lut_index2
][
off
]);
#endif
res2
=
__hfma2
(
tmp2
,
blockvec
[
k
+
3
],
res2
);
#ifndef USE_ROCM
res
=
__hadd
(
__hadd
(
res2
.
x
,
res2
.
y
),
res
);
#else
res
=
__hadd
(
__hadd
(
__ushort_as_half
(
res2
.
x
),
__ushort_as_half
(
res2
.
y
)),
res
);
#endif
i
+=
width
;
k
+=
4
;
}
// col%2 -> only set one of the two values
#ifndef USE_ROCM
half2
res3
=
{};
if
(
col
%
2
==
0
)
{
res3
.
x
=
res
;
}
else
{
res3
.
y
=
res
;
}
#else
__half2
res3
;
res3
.
x
=
__half_as_ushort
(
__float2half
(
0
));
res3
.
y
=
__half_as_ushort
(
__float2half
(
0
));
if
(
col
%
2
==
0
)
{
res3
.
x
=
__half_as_ushort
(
res
);
}
else
{
res3
.
y
=
__half_as_ushort
(
res
);
}
#endif
#ifndef USE_ROCM
atomicAdd
(
&
mul
[
b
*
width
/
2
+
col
/
2
],
res3
);
#else
int
tmp_addr
=
b
*
width
/
2
+
col
/
2
;
atomicAdd
(
&
(
mul
[
tmp_addr
].
x
),
__half2float
(
__ushort_as_half
(
res3
.
x
)));
atomicAdd
(
&
(
mul
[
tmp_addr
].
y
),
__half2float
(
__ushort_as_half
(
res3
.
y
)));
#endif
}
}
}
// namespace squeezellm
}
// namespace vllm
// 4-bit matvec kernel (LUT-based)
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
)
{
int
height
=
mat
.
size
(
0
);
int
width
=
mat
.
size
(
1
);
int
batch
=
vec
.
size
(
0
);
int
vec_height
=
vec
.
size
(
1
);
dim3
blocks
((
height
+
BLOCKHEIGHT4
-
1
)
/
BLOCKHEIGHT4
,
(
width
+
BLOCKWIDTH
-
1
)
/
BLOCKWIDTH
);
dim3
threads
(
BLOCKWIDTH
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
vec
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
vllm
::
squeezellm
::
NUQ4MatMulKernel
<<<
blocks
,
threads
,
0
,
stream
>>>
(
#ifndef USE_ROCM
(
half2
*
)
vec
.
data_ptr
<
at
::
Half
>
(),
#else
(
__half2
*
)
vec
.
data_ptr
<
at
::
Half
>
(),
#endif
mat
.
data_ptr
<
int
>
(),
#ifndef USE_ROCM
(
half2
*
)
mul
.
data_ptr
<
at
::
Half
>
(),
(
__half
*
)
lookup_table
.
data_ptr
<
at
::
Half
>
(),
#else
(
float2
*
)
mul
.
data_ptr
<
float
>
(),
(
__half
*
)
lookup_table
.
data_ptr
<
at
::
Half
>
(),
#endif
height
,
width
,
batch
,
vec_height
);
}
#undef BLOCKWIDTH
#undef BLOCKHEIGHT4
csrc/torch_bindings.cpp
View file @
23f32229
...
...
@@ -237,12 +237,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"
);
ops
.
impl
(
"gptq_shuffle"
,
torch
::
kCUDA
,
&
gptq_shuffle
);
// Quantized GEMM for SqueezeLLM.
ops
.
def
(
"squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor "
"lookup_table) -> ()"
);
ops
.
impl
(
"squeezellm_gemm"
,
torch
::
kCUDA
,
&
squeezellm_gemm
);
// Compute FP8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"
);
...
...
docs/source/quantization/supported_hardware.rst
View file @
23f32229
...
...
@@ -119,17 +119,6 @@ The table below shows the compatibility of various quantization implementations
- ✗
- ✗
- ✗
* - SqueezeLLM
- ✅︎
- ✅︎
- ✅︎
- ✅︎
- ✅︎
- ✗
- ✗
- ✗
- ✗
- ✗
Notes:
^^^^^^
...
...
examples/fp8/README.md
View file @
23f32229
...
...
@@ -62,7 +62,7 @@ This script evaluates the inference throughput of language models using various
python3
benchmarks
/
benchmark_throughput
.
py
--
help
usage
:
benchmark_throughput
.
py
[
-
h
]
[
--
backend
{
vllm
,
hf
,
mii
}]
[
--
dataset
DATASET
]
[
--
input
-
len
INPUT_LEN
]
[
--
output
-
len
OUTPUT_LEN
]
[
--
model
MODEL
]
[
--
tokenizer
TOKENIZER
]
[
--
quantization
{
awq
,
gptq
,
squeezellm
,
None
}]
[
--
tensor
-
parallel
-
size
TENSOR_PARALLEL_SIZE
]
[
--
n
N
]
[
--
tokenizer
TOKENIZER
]
[
--
quantization
{
awq
,
gptq
,
None
}]
[
--
tensor
-
parallel
-
size
TENSOR_PARALLEL_SIZE
]
[
--
n
N
]
[
--
use
-
beam
-
search
]
[
--
num
-
prompts
NUM_PROMPTS
]
[
--
seed
SEED
]
[
--
hf
-
max
-
batch
-
size
HF_MAX_BATCH_SIZE
]
[
--
trust
-
remote
-
code
]
[
--
max
-
model
-
len
MAX_MODEL_LEN
]
[
--
dtype
{
auto
,
half
,
float16
,
bfloat16
,
float
,
float32
}]
[
--
enforce
-
eager
]
[
--
kv
-
cache
-
dtype
{
auto
,
fp8
}]
[
--
quantization
-
param
-
path
KV_CACHE_quantization_param_path
]
...
...
@@ -76,7 +76,7 @@ optional arguments:
--
output
-
len
OUTPUT_LEN
Output
length
for
each
request
.
Overrides
the
output
length
from
the
dataset
.
--
model
MODEL
--
tokenizer
TOKENIZER
--
quantization
{
awq
,
gptq
,
squeezellm
,
None
},
-
q
{
awq
,
gptq
,
squeezellm
,
None
}
--
quantization
{
awq
,
gptq
,
None
},
-
q
{
awq
,
gptq
,
None
}
--
tensor
-
parallel
-
size
TENSOR_PARALLEL_SIZE
,
-
tp
TENSOR_PARALLEL_SIZE
--
n
N
Number
of
generated
sequences
per
prompt
.
--
use
-
beam
-
search
...
...
vllm/_custom_ops.py
View file @
23f32229
...
...
@@ -209,12 +209,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
torch
.
ops
.
_C
.
gptq_shuffle
(
q_weight
,
q_perm
,
bit
)
# squeezellm
def
squeezellm_gemm
(
vec
:
torch
.
Tensor
,
mat
:
torch
.
Tensor
,
mul
:
torch
.
Tensor
,
lookup_table
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
squeezellm_gemm
(
vec
,
mat
,
mul
,
lookup_table
)
# marlin
def
marlin_gemm
(
a
:
torch
.
Tensor
,
b_q_weight
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
workspace
:
torch
.
Tensor
,
size_m
:
int
,
...
...
vllm/config.py
View file @
23f32229
...
...
@@ -277,7 +277,7 @@ class ModelConfig:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"awq"
,
"gptq"
,
"squeezellm"
,
"fp8"
]
rocm_supported_quantization
=
[
"awq"
,
"gptq"
,
"fp8"
]
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
...
...
@@ -1537,7 +1537,7 @@ class LoRAConfig:
if
model_config
.
quantization
and
model_config
.
quantization
not
in
[
"awq"
,
"gptq"
]:
# TODO support marlin
and squeezellm
# TODO support marlin
logger
.
warning
(
"%s quantization is not tested with LoRA yet."
,
model_config
.
quantization
)
...
...
vllm/entrypoints/llm.py
View file @
23f32229
...
...
@@ -55,7 +55,7 @@ class LLM:
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq", "gptq",
"squeezellm",
and "fp8" (experimental).
we support "awq", "gptq", and "fp8" (experimental).
If None, we first check the `quantization_config` attribute in the
model config file. If that is None, we assume the model weights are
not quantized and use `dtype` to determine the data type of
...
...
vllm/lora/layers.py
View file @
23f32229
...
...
@@ -39,7 +39,7 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
# unquantizedLinear
if
hasattr
(
base_layer
,
"weight"
):
return
base_layer
.
weight
.
device
# GPTQ/AWQ
/SqueezeLLM
# GPTQ/AWQ
elif
hasattr
(
base_layer
,
"qweight"
):
return
base_layer
.
qweight
.
device
# marlin
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
23f32229
...
...
@@ -25,7 +25,6 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from
vllm.model_executor.layers.quantization.neuron_quant
import
(
NeuronQuantConfig
)
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.tpu_int8
import
Int8TpuConfig
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
...
...
@@ -43,7 +42,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"gptq_marlin"
:
GPTQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"bitsandbytes"
:
BitsAndBytesConfig
,
"qqq"
:
QQQConfig
,
...
...
vllm/model_executor/layers/quantization/squeezellm.py
deleted
100644 → 0
View file @
9db52eab
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
LinearBase
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_hip
class
SqueezeLLMConfig
(
QuantizationConfig
):
"""Config class for SqueezeLLM.
Reference: https://arxiv.org/pdf/2306.07629
"""
def
__init__
(
self
,
weight_bits
:
int
,
)
->
None
:
self
.
weight_bits
=
weight_bits
if
self
.
weight_bits
!=
4
:
raise
ValueError
(
"Currently, only 4-bit weight quantization is supported for "
f
"SqueezeLLM, but got
{
self
.
weight_bits
}
bits."
)
self
.
pack_factor
=
32
//
self
.
weight_bits
def
__repr__
(
self
)
->
str
:
return
f
"SqueezeLLMConfig(weight_bits=
{
self
.
weight_bits
}
)"
def
get_name
(
self
)
->
str
:
return
"squeezellm"
def
get_supported_act_dtypes
(
self
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
70
@
staticmethod
def
get_config_filenames
()
->
List
[
str
]:
return
[
"quant_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"SqueezeLLMConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"wbits"
])
return
cls
(
weight_bits
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
QuantizeMethodBase
]:
if
isinstance
(
layer
,
LinearBase
):
return
SqueezeLLMLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
SqueezeLLMLinearMethod
(
QuantizeMethodBase
):
"""Linear method for SqueezeLLM.
Args:
quant_config: The SqueezeLLM quantization config.
"""
def
__init__
(
self
,
quant_config
:
SqueezeLLMConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
if
input_size_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size."
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
lookup_table
=
Parameter
(
torch
.
empty
(
output_size
,
self
.
quant_config
.
weight_bits
**
2
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
lookup_table
,
{
"output_dim"
:
0
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"lookup_table"
,
lookup_table
)
set_weight_attrs
(
lookup_table
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
lookup_table
=
layer
.
lookup_table
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
is_hip
():
out_f
=
torch
.
zeros
(
out_shape
,
dtype
=
torch
.
float
)
ops
.
squeezellm_gemm
(
reshaped_x
,
qweight
,
out_f
,
lookup_table
)
out
=
out_f
.
to
(
dtype
=
torch
.
float16
)
else
:
# NOTE: The output tensor should be zero-initialized.
out
=
torch
.
zeros
(
out_shape
,
dtype
=
torch
.
float16
)
ops
.
squeezellm_gemm
(
reshaped_x
,
qweight
,
out
,
lookup_table
)
if
bias
is
not
None
:
out
.
add_
(
bias
)
return
out
.
reshape
(
out_shape
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment