OpenDAS / text-generation-inference

Commit 880a76ee (unverified)
feat(server): support sharded santacoder (#167)
Authored Apr 12, 2023 by OlivierDehaene; committed via GitHub on Apr 12, 2023.
parent 5fa8ae04
Showing 11 changed files with 463 additions and 49 deletions (+463 / -49)
server/text_generation_server/models/__init__.py (+11 / -3)
server/text_generation_server/models/bloom.py (+4 / -2)
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (+1 / -1)
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py (+6 / -1)
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py (+198 / -13)
server/text_generation_server/models/flash_neox.py (+6 / -8)
server/text_generation_server/models/flash_santacoder.py (+221 / -13)
server/text_generation_server/models/galactica.py (+4 / -2)
server/text_generation_server/models/gpt_neox.py (+4 / -2)
server/text_generation_server/models/opt.py (+4 / -2)
server/text_generation_server/models/t5.py (+4 / -2)
server/text_generation_server/models/__init__.py

@@ -18,8 +18,11 @@ from text_generation_server.models.t5 import T5Sharded
 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
-    from text_generation_server.models.flash_santacoder import FlashSantacoder
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
+    from text_generation_server.models.flash_santacoder import (
+        FlashSantacoder,
+        FlashSantacoderSharded,
+    )
 
     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
@@ -49,6 +52,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashSantacoder)
+    __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
     __all__.append(FlashLlamaSharded)
@@ -78,9 +82,13 @@ def get_model(
     else:
         return Galactica(model_id, revision, quantize=quantize)
 
-    if "santacoder" in model_id:
+    if "bigcode" in model_id:
         if sharded:
-            raise NotImplementedError("sharded is not supported for Santacoder")
+            if not FLASH_ATTENTION:
+                raise NotImplementedError(
+                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
+                )
+            return FlashSantacoderSharded(model_id, revision=revision)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
             return santacoder_cls(model_id, revision, quantize)
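For context, a hedged sketch of how the new dispatch is exercised. The positional signature used below is an assumption made for illustration, not something shown in this diff:

# Illustrative call only: get_model(model_id, revision, sharded, quantize)
# is an assumed signature.
from text_generation_server.models import get_model

# With flash attention available, a "bigcode" model id and sharded=True now
# resolve to FlashSantacoderSharded instead of raising NotImplementedError.
model = get_model("bigcode/santacoder", None, True, False)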
server/text_generation_server/models/bloom.py

@@ -93,10 +93,11 @@ class BLOOMSharded(BLOOM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer, device=device, decode_buffer=1

@@ -108,6 +109,7 @@ class BLOOMSharded(BLOOM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):

@@ -157,7 +159,7 @@ class BLOOMSharded(BLOOM):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )
 
-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)
 
                 if quantize:
                     if not HAS_BITS_AND_BYTES:
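The same refactor repeats in galactica.py, gpt_neox.py, opt.py and t5.py below: the target dtype is threaded into load_weights and each tensor is cast as it is copied, instead of casting the whole model after loading. A minimal sketch of why that ordering matters for peak memory (the loader function and shapes are invented for illustration):

import torch

def read_param(shape):
    # Stand-in for deserializing one float32 tensor from a checkpoint shard.
    return torch.randn(shape)

# Old pattern: assemble the whole float32 model, then `model.eval().to(dtype)`,
# which briefly holds both precisions in memory. New pattern: cast each tensor
# as it is loaded, so at most one float32 copy exists at a time.
dtype = torch.float16
params = {
    name: read_param(shape).contiguous().to(dtype)
    for name, shape in [("wte.weight", (1024, 64)), ("fc.weight", (256, 64))]
}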
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py

@@ -373,7 +373,7 @@ class LlamaMLP(nn.Module):
                 x,
                 approximate="tanh"
                 if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else None,
+                else "none",
             )
         )
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py

@@ -376,7 +376,12 @@ class FlashMLP(nn.Module):
         self.act = (
             ACT2FN[act]
             if "gelu" not in act
-            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
+            else lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
         )
 
         if process_group is None:
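Both MLP fixes above hinge on the same detail: torch.nn.functional.gelu takes a string for its approximate argument, so the fallback must be "none", not None. A quick check:

import torch
import torch.nn.functional as F

x = torch.randn(4)
exact = F.gelu(x, approximate="none")  # erf-based GELU (the default)
tanh = F.gelu(x, approximate="tanh")   # fast approximation used by gelu_fast
# The two variants agree closely but not bit-for-bit:
print(torch.allclose(exact, tanh, atol=1e-2), torch.equal(exact, tanh))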
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py

 import torch
 import torch.distributed
+import torch.nn.functional as F
+
 from torch import nn
 from transformers.activations import ACT2FN
@@ -65,6 +67,127 @@ class FastLinear(nn.Linear):
         return torch.matmul(input, self.weight)
 
 
+class TensorParallelColumnLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        assert out_features % self.tp_world_size == 0
+        out_features = out_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class TensorParallelRowLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+        assert in_features % self.tp_world_size == 0
+        in_features = in_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = super(TensorParallelRowLinear, self).forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+
+class TensorParallelEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.0,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+
+        self.original_num_embeddings = num_embeddings
+
+        assert num_embeddings % self.tp_world_size == 0
+        block_size = num_embeddings // self.tp_world_size
+        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
+        self.min_id = self.tp_rank * block_size
+        self.max_id = (self.tp_rank + 1) * block_size
+
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
+        super().__init__(
+            block_size,
+            embedding_dim,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            _weight=_weight,
+            device=device,
+            dtype=dtype,
+        )
+
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
+        # translate for [0, self.max_id - self.min_id[
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
+        out = super().forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+        return out
+
+
 class FlashMQAttention(torch.nn.Module):
     def __init__(
         self,
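These helpers follow the usual column/row tensor-parallel split: a column-parallel linear shards its output features, a row-parallel linear shards its input features and all-reduces the partial sums. A self-contained sketch (toy shapes, no torch.distributed) of why the composition is exact:

import torch

hidden, intermediate, world_size = 8, 16, 2
x = torch.randn(1, hidden)
w_fc = torch.randn(intermediate, hidden)    # column parallel: split output features
w_proj = torch.randn(hidden, intermediate)  # row parallel: split input features

reference = x @ w_fc.T @ w_proj.T  # the unsharded computation

block = intermediate // world_size
partials = []
for rank in range(world_size):
    shard = slice(rank * block, (rank + 1) * block)
    h = x @ w_fc[shard].T                    # each rank owns a slice of the hidden features
    partials.append(h @ w_proj[:, shard].T)  # and produces a partial sum over its slice

# Summing the partials plays the role of all_reduce in TensorParallelRowLinear.
print(torch.allclose(sum(partials), reference, atol=1e-5))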
@@ -80,10 +203,16 @@ class FlashMQAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)
 
         if process_group is None:
-            self.attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
+            self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
             self.c_proj = FastLinear(hidden_size, hidden_size)
         else:
-            raise NotImplementedError
+            self.num_heads = self.num_heads // process_group.size()
+            self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
+            self.c_proj = TensorParallelRowLinear(
+                hidden_size,
+                hidden_size,
+                process_group=process_group,
+            )
 
     def forward(
         self,
@@ -94,10 +223,12 @@ class FlashMQAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        qkv = self.attn(hidden_states)
+        qkv = self.c_attn(hidden_states)
 
         # Split query from key_value
-        query, key_value = qkv.split([self.hidden_size, 2 * self.head_size], dim=1)
+        query, key_value = qkv.split(
+            [self.head_size * self.num_heads, 2 * self.head_size], dim=1
+        )
 
         # Prepare query and key_value for indexing
         query = query.view(-1, self.num_heads, self.head_size)
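The renamed c_attn now projects to head_size * (num_heads + 2) per rank: under multi-query attention the query heads are divided across shards, while the single shared key/value pair is kept whole on every rank. Toy shapes (no real weights) showing the split:

import torch

hidden_size, num_heads, world_size = 64, 8, 2
head_size = hidden_size // num_heads
local_heads = num_heads // world_size  # query heads owned by one rank

# Per-rank fused projection output: local query heads + one shared K and V head.
qkv = torch.randn(3, head_size * (local_heads + 2))
query, key_value = qkv.split([head_size * local_heads, 2 * head_size], dim=1)
print(query.view(-1, local_heads, head_size).shape)  # (3, 4, 8)
print(key_value.view(-1, 2, head_size).shape)        # (3, 2, 8)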
@@ -171,7 +302,7 @@ class MLP(nn.Module):
                 x,
                 approximate="tanh"
                 if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                else None,
+                else "none",
             )
         )
@@ -179,7 +310,16 @@ class MLP(nn.Module):
             self.c_fc = FastLinear(hidden_size, intermediate_size)
             self.c_proj = FastLinear(intermediate_size, hidden_size)
         else:
-            raise NotImplementedError
+            self.c_fc = TensorParallelColumnLinear(
+                hidden_size,
+                intermediate_size,
+                process_group=process_group,
+            )
+            self.c_proj = TensorParallelRowLinear(
+                intermediate_size,
+                hidden_size,
+                process_group=process_group,
+            )
 
     def forward(self, hidden_states):
         hidden_states = self.c_fc(hidden_states)
@@ -246,9 +386,28 @@ class FlashSantacoderModel(nn.Module):
         super().__init__()
         self.config = config
 
+        self.process_group = process_group
+        self.tp_embeddings = False
         if process_group is not None:
-            raise NotImplementedError
+            self.tp_rank = process_group.rank()
+            self.tp_world_size = process_group.size()
+            if config.vocab_size % self.tp_world_size == 0:
+                self.tp_embeddings = True
 
-        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        if self.tp_embeddings:
+            self.wte = TensorParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                reduce=False,
+                process_group=process_group,
+            )
+            self.wpe = TensorParallelEmbedding(
+                config.max_position_embeddings,
+                config.hidden_size,
+                reduce=False,
+                process_group=process_group,
+            )
+        else:
+            self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
+            self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
@@ -273,9 +432,12 @@ class FlashSantacoderModel(nn.Module):
         self.num_heads = self.h[0].attn.num_heads
 
     def post_load_weights(self):
+        if self.tp_embeddings:
+            self.wte.add_null_idx()
+            self.wpe.add_null_idx()
         for layer in self.h:
             layer: Block
-            layer.attn.attn.transpose_weight()
+            layer.attn.c_attn.transpose_weight()
             layer.attn.c_proj.transpose_weight()
             layer.mlp.c_fc.transpose_weight()
             layer.mlp.c_proj.transpose_weight()
@@ -289,6 +451,8 @@ class FlashSantacoderModel(nn.Module):
         past_key_values=None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
+        if self.tp_embeddings:
+            torch.distributed.all_reduce(hidden_states, group=self.process_group)
 
         # Prefill
         if past_key_values is None:
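Note that wte and wpe are built with reduce=False, so their partial lookups are summed first and the single all_reduce above covers both tables at once. A single-process sketch of the vocabulary-sharded lookup, with made-up sizes:

import torch
import torch.nn.functional as F

vocab, dim, world_size = 10, 4, 2
weight = torch.randn(vocab, dim)        # the full embedding table
input_ids = torch.tensor([0, 3, 7, 9])

block = vocab // world_size
partials = []
for rank in range(world_size):
    min_id, max_id = rank * block, (rank + 1) * block
    # Each shard keeps its rows plus one trailing zero row (the null_idx).
    shard = F.pad(weight[min_id:max_id], (0, 0, 0, 1))
    local_ids = torch.where(
        (input_ids < min_id) | (input_ids >= max_id),
        torch.tensor(block),            # out-of-range ids hit the zero row
        input_ids - min_id,
    )
    partials.append(shard[local_ids])

# Exactly one shard contributes a non-zero row per token, so the sum
# (an all_reduce in the real code) reproduces the full lookup.
print(torch.allclose(sum(partials), weight[input_ids]))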
@@ -335,6 +499,13 @@ class FlashSantacoderForCausalLM(nn.Module):
         self.transformer = FlashSantacoderModel(config, process_group)
 
-        self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
+        if self.transformer.tp_embeddings:
+            self.lm_head = FastLinear(
+                config.hidden_size,
+                config.vocab_size // process_group.size(),
+                bias=False,
+            )
+        else:
+            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
 
     def post_load_weights(self):
@@ -352,4 +523,18 @@ class FlashSantacoderForCausalLM(nn.Module):
         hidden_states, present = self.transformer(
             input_ids, position_ids, cu_seqlens, max_s, past_key_values
         )
-        return self.lm_head(hidden_states), present
+        logits = self.lm_head(hidden_states)
+
+        if self.transformer.tp_embeddings:
+            # Logits are sharded, so we need to gather them
+            world_logits = [
+                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
+            ]
+            torch.distributed.all_gather(
+                world_logits, logits, group=self.transformer.process_group
+            )
+            world_logits = torch.cat(world_logits, dim=1)
+
+            return world_logits, present
+
+        return logits, present
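Each rank's lm_head covers vocab_size // world_size entries, so the gathered pieces concatenate along the vocabulary axis. A single-process stand-in for the all_gather:

import torch

hidden, vocab, world_size = 4, 10, 2
hidden_states = torch.randn(3, hidden)
lm_head = torch.randn(vocab, hidden)  # full, unsharded head for reference

block = vocab // world_size
world_logits = [
    hidden_states @ lm_head[r * block : (r + 1) * block].T  # one rank's shard
    for r in range(world_size)
]
gathered = torch.cat(world_logits, dim=1)  # what all_gather + cat reconstructs
print(torch.allclose(gathered, hidden_states @ lm_head.T, atol=1e-5))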
server/text_generation_server/models/flash_neox.py

@@ -5,7 +5,7 @@ from accelerate import init_empty_weights
 from opentelemetry import trace
 from safetensors import safe_open
 from transformers import AutoTokenizer, AutoConfig
-from typing import Optional, Tuple, List
+from typing import Optional, List
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_neox_modeling import (

@@ -63,13 +63,13 @@ class FlashNeoXSharded(FlashNeoX):
         self.load_weights(
             model,
             filenames,
-            quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
         model.post_load_weights()
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -80,16 +80,14 @@ class FlashNeoXSharded(FlashNeoX):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
-            with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
-            ) as f:
+            with safe_open(file, framework="pt", device=str(device)) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
                     module = model.get_submodule(module_name)

@@ -142,7 +140,7 @@ class FlashNeoXSharded(FlashNeoX):
                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                     )
 
-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)
 
                 if current_parameter_tensor is not None:
                     module._parameters[param_name] = tensor
server/text_generation_server/models/flash_santacoder.py

@@ -3,15 +3,20 @@ import torch.distributed
 from accelerate import init_empty_weights
 from opentelemetry import trace
+from safetensors import safe_open
 from pathlib import Path
-from transformers import AutoTokenizer, AutoConfig
+from transformers import AutoTokenizer, GPT2Config
 from typing import Optional, List
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
     FlashSantacoderForCausalLM,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
 )
 from text_generation_server.utils import (
+    initialize_torch_distributed,
     weight_files,
     download_weights,
     weight_hub_files,
@@ -36,10 +41,9 @@ class FlashSantacoder(FlashCausalLM):
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
 
-        config = AutoConfig.from_pretrained(
+        config = GPT2Config.from_pretrained(
             model_id,
             revision=revision,
-            trust_remote_code=True,  # Needed as the config is not part of Transformers
         )
 
         # We do not use from_pretrained as we modified the model internal module layout

@@ -54,12 +58,9 @@ class FlashSantacoder(FlashCausalLM):
         model = FlashSantacoderForCausalLM(config)
 
         self.load_weights(
-            model,
-            filenames,
-            device,
-            dtype,
+            model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
         )
-        self.model = model.eval().to(device).to(dtype)
+        self.model = model.eval()
 
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer, device=device, decode_buffer=1
@@ -71,6 +72,7 @@ class FlashSantacoder(FlashCausalLM):
         filenames: List[Path],
         device: torch.device,
         dtype: torch.dtype,
+        transpose: bool,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")

@@ -81,9 +83,9 @@ class FlashSantacoder(FlashCausalLM):
                 # Fused qkv
                 if "q_attn.weight" in key or "kv_attn.weight" in key:
-                    final_key = layer_name + ".attn.weight"
+                    final_key = layer_name + ".c_attn.weight"
                 elif "q_attn.bias" in key or "kv_attn.bias" in key:
-                    final_key = layer_name + ".attn.bias"
+                    final_key = layer_name + ".c_attn.bias"
                 else:
                     final_key = key

@@ -97,18 +99,19 @@ class FlashSantacoder(FlashCausalLM):
                     current_parameter_tensor = None
 
                 if current_parameter_tensor is not None:
-                    if (
+                    if transpose and (
                         "c_fc.weight" in key
                         or "c_proj.weight" in key
                         or "q_attn.weight" in key
                         or "kv_attn.weight" in key
+                        or "c_attn.weight" in key
                     ):
                         # Tranpose as we use nn.Linear instead of Conv1D
                         value = value.T
 
                     if current_parameter_tensor.device == torch.device("meta"):
                         # Init qkv
-                        if "attn.weight" in final_key:
+                        if "c_attn.weight" in final_key:
                             module._parameters[param_name] = value.new_empty(
                                 (
                                     model.transformer.head_size

@@ -116,7 +119,7 @@ class FlashSantacoder(FlashCausalLM):
                                     value.shape[1],
                                 )
                             )
-                        elif "attn.bias" in final_key:
+                        elif "c_attn.bias" in final_key:
                             module._parameters[param_name] = value.new_empty(
                                 (
                                     model.transformer.head_size
@@ -156,3 +159,208 @@ class FlashSantacoder(FlashCausalLM):
         return self.tokenizer.decode(
             generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
         )
+
+
+class FlashSantacoderSharded(FlashSantacoder):
+    def __init__(
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+    ):
+        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
+        self.master = self.rank == 0
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{self.rank}")
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        else:
+            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
+
+        if quantize:
+            raise NotImplementedError(
+                "FlashSantacoderSharded does not support quantization"
+            )
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id, revision=revision, padding_side="left", truncation_side="left"
+        )
+
+        config = GPT2Config.from_pretrained(
+            model_id,
+            revision=revision,
+        )
+
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+
+        with init_empty_weights():
+            model = FlashSantacoderForCausalLM(config, self.process_group)
+
+        torch.distributed.barrier(group=self.process_group)
+        self.load_weights(
+            model,
+            filenames,
+            device=device,
+            dtype=dtype,
+            rank=self.rank,
+            world_size=self.world_size,
+            transpose=config.architectures[0].startswith("GPT2"),
+        )
+        self.model = model.eval()
+        torch.distributed.barrier(group=self.process_group)
+        super(FlashCausalLM, self).__init__(
+            tokenizer=tokenizer,
+            device=device,
+        )
+
+    @staticmethod
+    def load_weights(
+        model,
+        filenames: List[str],
+        device: torch.device,
+        dtype: torch.dtype,
+        rank: int,
+        world_size: int,
+        transpose: bool,
+    ):
+        for file in filenames:
+            with safe_open(file, framework="pt", device=str(device)) as f:
+                for key in f.keys():
+                    slice_ = f.get_slice(key)
+
+                    layer_name = ".".join(key.split(".")[:4])
+
+                    # Fused qkv
+                    if "q_attn.weight" in key or "kv_attn.weight" in key:
+                        final_key = layer_name + ".c_attn.weight"
+                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
+                        final_key = layer_name + ".c_attn.bias"
+                    else:
+                        final_key = key
+
+                    module_name, param_name = final_key.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    if isinstance(module, TensorParallelColumnLinear):
+                        dim = 1 if transpose and "weight" in param_name else 0
+                        size = slice_.get_shape()[dim]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = (
+                            slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        )
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            dim = 0 if transpose else 1
+                            size = slice_.get_shape()[dim]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = (
+                                slice_[start:stop]
+                                if dim == 0
+                                else slice_[:, start:stop]
+                            )
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif key == "lm_head.weight" and model.transformer.tp_embeddings:
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    else:
+                        try:
+                            tensor = slice_[:]
+                        except:
+                            tensor = f.get_tensor(key)
+
+                    tensor = tensor.contiguous().to(dtype)
+
+                    try:
+                        current_parameter_tensor = module._parameters[param_name]
+                    except KeyError:
+                        current_parameter_tensor = None
+
+                    if current_parameter_tensor is not None:
+                        if transpose and (
+                            "c_fc.weight" in key
+                            or "c_proj.weight" in key
+                            or "q_attn.weight" in key
+                            or "kv_attn.weight" in key
+                            or "c_attn.weight" in key
+                        ):
+                            # Tranpose as we use nn.Linear instead of Conv1D
+                            tensor = tensor.T
+
+                        if current_parameter_tensor.device == torch.device("meta"):
+                            # Init qkv
+                            if "c_attn.weight" in final_key:
+                                module._parameters[param_name] = tensor.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2),
+                                        tensor.shape[1],
+                                    )
+                                )
+                            elif "c_attn.bias" in final_key:
+                                module._parameters[param_name] = tensor.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2)
+                                    )
+                                )
+
+                        # Copy to correct slice
+                        if "q_attn" in key:
+                            size = tensor.shape[0]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = tensor[start:stop]
+                            module._parameters[param_name][: tensor.shape[0]] = tensor
+                        elif "kv_attn.weight" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
+                            ] = tensor
+                        elif "kv_attn.bias" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
+                            ] = tensor
+                        elif "c_attn" in key:
+                            # Slice q_tensor by shard
+                            q_tensor = tensor[: -2 * model.transformer.head_size]
+                            block_size = q_tensor.shape[0] // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            q_tensor = q_tensor[start:stop]
+                            module._parameters[param_name][
+                                : q_tensor.shape[0]
+                            ] = q_tensor
+                            # Kv tensor is copied for every shard
+                            kv_tensor = tensor[-2 * model.transformer.head_size :]
+                            module._parameters[param_name][
+                                q_tensor.shape[0] :
+                            ] = kv_tensor
+                        else:
+                            if current_parameter_tensor.shape != tensor.shape:
+                                raise ValueError(
+                                    f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+                                )
+                            module._parameters[param_name] = tensor
+                    else:
+                        module._buffers[param_name] = tensor
+
+        torch.cuda.empty_cache()
+        model.post_load_weights()
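The sharded loader leans on safetensors' lazy slicing: f.get_slice returns a handle whose indexing reads only the requested rows, so a rank never materializes tensors it does not own. A small self-contained demo (file name and tensor key are invented for the example):

import torch
from safetensors import safe_open
from safetensors.torch import save_file

save_file({"wte.weight": torch.randn(10, 4)}, "demo.safetensors")

rank, world_size = 0, 2
with safe_open("demo.safetensors", framework="pt") as f:
    slice_ = f.get_slice("wte.weight")
    size = slice_.get_shape()[0]
    block_size = size // world_size
    shard = slice_[rank * block_size : (rank + 1) * block_size]  # reads 5 rows only
print(shard.shape)  # torch.Size([5, 4])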
server/text_generation_server/models/galactica.py

@@ -219,10 +219,11 @@ class GalacticaSharded(Galactica):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -235,6 +236,7 @@ class GalacticaSharded(Galactica):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):

@@ -285,7 +287,7 @@ class GalacticaSharded(Galactica):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )
 
-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)
 
                 if quantize:
                     if not HAS_BITS_AND_BYTES:
server/text_generation_server/models/gpt_neox.py

@@ -64,10 +64,11 @@ class GPTNeoxSharded(CausalLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -80,6 +81,7 @@ class GPTNeoxSharded(CausalLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):

@@ -140,7 +142,7 @@ class GPTNeoxSharded(CausalLM):
                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                     )
 
-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)
 
                 if quantize:
                     if not HAS_BITS_AND_BYTES:
server/text_generation_server/models/opt.py

@@ -80,10 +80,11 @@ class OPTSharded(OPT):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,

@@ -96,6 +97,7 @@ class OPTSharded(OPT):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):

@@ -146,7 +148,7 @@ class OPTSharded(OPT):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )
 
-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)
 
                 if quantize:
                     if not HAS_BITS_AND_BYTES:
server/text_generation_server/models/t5.py

@@ -64,10 +64,11 @@ class T5Sharded(Seq2SeqLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,

@@ -80,6 +81,7 @@ class T5Sharded(Seq2SeqLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):

@@ -146,7 +148,7 @@ class T5Sharded(Seq2SeqLM):
                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                     )
 
-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)
 
                 if quantize:
                     if not HAS_BITS_AND_BYTES: