OpenDAS / text-generation-inference / Commits

Commit d077150e (unverified)
Authored Dec 18, 2023 by OlivierDehaene; committed via GitHub on Dec 18, 2023

fix: fix gpt-q with groupsize = -1 (#1358)

Parent: 8428ed10
Showing 6 changed files with 43 additions and 45 deletions (+43 -45)
Files changed:

proto/generate.proto                                    +3   -0
router/client/src/client.rs                             +7   -1
server/text_generation_server/server.py                 +24  -17
server/text_generation_server/utils/gptq/exllama.py     +2   -9
server/text_generation_server/utils/gptq/exllamav2.py   +1   -12
server/text_generation_server/utils/weights.py          +6   -6
proto/generate.proto

@@ -213,6 +213,9 @@ message DecodeResponse {

 message WarmupRequest {
     /// Batch to warmup on
     Batch batch = 1;
+    uint32 max_input_length = 2;
+    uint32 max_prefill_tokens = 3;
+    uint32 max_total_tokens = 4;
 }

 /// Empty response
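The warmup request now carries the sizing information the Python server needs when it allocates kernel buffers. As a quick illustration (not part of the commit), the new fields can be set through the generated Python stubs; the import path assumes the stubs live under text_generation_server.pb, as elsewhere in the server, and the values are placeholders:

    # Illustrative sketch only; values are placeholders, not defaults from the commit.
    from text_generation_server.pb import generate_pb2

    req = generate_pb2.WarmupRequest(
        max_input_length=1024,    # field 2, added here
        max_prefill_tokens=4096,  # field 3, added here
        max_total_tokens=2048,    # field 4, added here
    )
    # `batch` (field 1) would be set to a real Batch message before calling Warmup.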
router/client/src/client.rs

@@ -145,7 +145,13 @@ impl Client {
             max_tokens: 0,
         };

-        let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
+        let request = tonic::Request::new(WarmupRequest {
+            batch: Some(batch),
+            max_input_length,
+            max_prefill_tokens,
+            max_total_tokens,
+        })
+        .inject_context();
         let response = self.stub.warmup(request).await?.into_inner();
         Ok(response.max_supported_total_tokens)
     }
server/text_generation_server/server.py

@@ -19,9 +19,16 @@ from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch

 class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
-    def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
+    def __init__(
+        self,
+        model: Model,
+        cache: Cache,
+        quantize: Optional[str],
+        server_urls: List[str],
+    ):
         self.cache = cache
         self.model = model
+        self.quantize = quantize
         self.server_urls = server_urls
         # For some reason, inference_mode does not work well with GLOO which we use on CPU
         if model.device.type == "cuda":

@@ -56,6 +63,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
         return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

     async def Warmup(self, request, context):
+        if self.quantize == "gptq":
+            try:
+                # When using GPTQ, Exllama kernels need some global kernels
+                # For which we have the finale shapes only after the model has loaded
+                # This will allocate those buffers.
+                from text_generation_server.utils.layers import (
+                    create_exllama_buffers,
+                    set_device,
+                )
+
+                set_device(self.model.device)
+                create_exllama_buffers(request.max_prefill_tokens)
+            except ImportError:
+                pass
+
         if (
             self.model.batch_type == IdeficsCausalLMBatch
         ):  # Hack, i would rather use kwargs in the `from_pb` call

@@ -184,21 +206,6 @@ def serve(
             logger.exception("Error when initializing model")
             raise

-        if quantize == "gptq":
-            try:
-                # When using GPTQ, Exllama kernels need some global kernels
-                # For which we have the finale shapes only after the model has loaded
-                # This will allocate those buffers.
-                from text_generation_server.utils.layers import (
-                    create_exllama_buffers,
-                    set_device,
-                )
-
-                set_device(model.device)
-                create_exllama_buffers()
-            except ImportError:
-                pass
-
         server = aio.server(
             interceptors=[
                 ExceptionInterceptor(),

@@ -206,7 +213,7 @@ def serve(
             ]
         )
         generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
-            TextGenerationService(model, Cache(), server_urls), server
+            TextGenerationService(model, Cache(), quantize, server_urls), server
         )
         SERVICE_NAMES = (
             generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
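Net effect on the server: the exllama scratch buffers are no longer allocated in serve() with no size information, but inside Warmup, where the request now carries max_prefill_tokens. A condensed sketch of that new path (a restatement of the diff above, not additional repo code; the argument value would come from the warmup request):

    # Sketch: what the gptq branch of Warmup now does, in isolation.
    from text_generation_server.utils.layers import create_exllama_buffers, set_device

    def warmup_gptq_buffers(model, max_prefill_tokens: int) -> None:
        # server.py only reaches this when self.quantize == "gptq" and wraps the
        # import in try/except ImportError, since the exllama kernels are optional.
        set_device(model.device)
        create_exllama_buffers(max_prefill_tokens)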
server/text_generation_server/utils/gptq/exllama.py

@@ -37,19 +37,12 @@ def set_device(device):
     DEVICE = device


-def create_exllama_buffers():
+def create_exllama_buffers(max_total_tokens: int):
     global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ

     assert DEVICE is not None, "call set_device first"

-    if ACT_ORDER:
-        # TODO: this should be set to rust side `max_total_tokens`, but TGI
-        # does not offer an API to expose this variable to python, as this variable
-        # is handled by the client but it appears the model is initialized by the server.
-        # An alternative could be to initialize the buffers during warmup.
-        # Dummy
-        max_total_tokens = 2048
-    else:
+    if not ACT_ORDER:
         max_total_tokens = 1

     # This temp_state buffer is required to reorder X in the act-order case.
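The TODO is resolved: instead of a hard-coded 2048, the caller now supplies max_total_tokens, and it is only forced to 1 when the weights do not use act-order (no reordering buffer is needed then). For context, a rough sketch of how such a reordering buffer is typically sized; the actual allocation sits just below this hunk and is not shown here, so the function and parameter names are assumptions:

    # Sketch under assumptions; max_inner and device mirror the globals named above.
    import torch

    def allocate_temp_state(max_total_tokens: int, max_inner: int, device) -> torch.Tensor:
        # One fp16 row per token that may be in flight, so the act-order kernel
        # can reorder X; with ACT_ORDER False a single row suffices.
        return torch.zeros(
            (max_total_tokens, max_inner), dtype=torch.float16, device=device
        )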
server/text_generation_server/utils/gptq/exllamav2.py

@@ -101,7 +101,7 @@ def set_device(device):
     DEVICE = device


-def create_exllama_buffers():
+def create_exllama_buffers(max_total_tokens: int):
     global FIXED_BYTES, LAYERS, DEVICE
     temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)

@@ -138,17 +138,6 @@ class QuantLinear(nn.Module):
         self.bias = bias if bias is not None else None
         self.group_size = groupsize

-        infeatures = self.infeatures
-        outfeatures = self.outfeatures
-        assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
-        assert infeatures % self.group_size == 0
-        assert qzeros.shape == (
-            infeatures // self.group_size,
-            outfeatures // 32 * self.bits,
-        )
-        assert scales.shape == (infeatures // self.group_size, outfeatures)
-        assert g_idx.shape == (infeatures,), f"{g_idx.shape}, {infeatures}"
-
         global FIXED_BYTES, LAYERS
         FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
         LAYERS.append(self)
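These deleted assertions are what broke for groupsize = -1: such checkpoints pack a single quantization group over the whole input dimension, so dividing by the raw groupsize value gives nonsense row counts. A small numeric illustration (made-up dimensions, not repo code):

    # Why the removed shape checks cannot hold when groupsize == -1 (illustrative).
    infeatures, groupsize = 4096, -1   # -1 means one group spanning all input features

    print(infeatures % groupsize)      # 0      -> the modulo assert happened to pass
    print(infeatures // groupsize)     # -4096  -> but the expected qzeros/scales row count
                                       #           goes negative, while the checkpoint stores
                                       #           a single row, so the shape asserts fail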
server/text_generation_server/utils/weights.py

@@ -281,18 +281,18 @@ class Weights:
             else:
                 logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")

-            if use_exllama:
+            if use_exllama and groupsize != -1:
                 qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
                 scales = self.get_sharded(f"{prefix}.scales", dim=0)
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
-                g_idx = g_idx - g_idx[0]
             else:
-                # The triton kernel reorders the scales/zero points instead of the weight/activation.
-                # Thus, each rank needs the full qzeros/scales.
                 qzeros = self.get_tensor(f"{prefix}.qzeros")
                 scales = self.get_tensor(f"{prefix}.scales")
-                g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+
+            g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
+            if use_exllama:
+                g_idx = g_idx - g_idx[0]

             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         elif quantize == "awq":
             bits, groupsize = self._get_gptq_params()
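This is the core of the fix: with groupsize = -1 the checkpoint has a single row of qzeros/scales along dim 0, which cannot be split across tensor-parallel ranks, so each rank now loads those tensors in full and only g_idx stays sharded. A simplified illustration of the failure mode the old dim-0 sharding would hit (made-up sizes, not repo code):

    # Simplified dim-0 sharding in the style of get_sharded (illustrative only).
    import torch

    rank, world_size = 0, 2
    scales = torch.ones(1, 4096)                 # groupsize = -1 -> a single row of scales

    block_size = scales.shape[0] // world_size   # 1 // 2 == 0
    shard = scales[rank * block_size : (rank + 1) * block_size]
    print(shard.shape)                           # torch.Size([0, 4096]): every rank ends up empty
    # Hence groupsize = -1 now goes through get_tensor (a full copy per rank),
    # and only g_idx is still sharded along dim 0.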