OpenDAS / text-generation-inference / Commits / 880a76ee

Commit 880a76ee (Unverified)
Authored Apr 12, 2023 by OlivierDehaene; committed via GitHub on Apr 12, 2023

feat(server): support sharded santacoder (#167)
parent 5fa8ae04

Showing 11 changed files with 463 additions and 49 deletions (+463 −49)
server/text_generation_server/models/__init__.py                                    +11  −3
server/text_generation_server/models/bloom.py                                        +4  −2
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py         +1  −1
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py          +6  −1
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py  +198 −13
server/text_generation_server/models/flash_neox.py                                   +6  −8
server/text_generation_server/models/flash_santacoder.py                           +221 −13
server/text_generation_server/models/galactica.py                                    +4  −2
server/text_generation_server/models/gpt_neox.py                                     +4  −2
server/text_generation_server/models/opt.py                                          +4  −2
server/text_generation_server/models/t5.py                                           +4  −2
server/text_generation_server/models/__init__.py
@@ -18,8 +18,11 @@ from text_generation_server.models.t5 import T5Sharded

 try:
     from text_generation_server.models.flash_neox import FlashNeoX, FlashNeoXSharded
-    from text_generation_server.models.flash_santacoder import FlashSantacoder
     from text_generation_server.models.flash_llama import FlashLlama, FlashLlamaSharded
+    from text_generation_server.models.flash_santacoder import (
+        FlashSantacoder,
+        FlashSantacoderSharded,
+    )

     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
@@ -49,6 +52,7 @@ if FLASH_ATTENTION:
     __all__.append(FlashNeoX)
     __all__.append(FlashNeoXSharded)
     __all__.append(FlashSantacoder)
+    __all__.append(FlashSantacoderSharded)
     __all__.append(FlashLlama)
     __all__.append(FlashLlamaSharded)
@@ -78,9 +82,13 @@ def get_model(
         else:
             return Galactica(model_id, revision, quantize=quantize)

-    if "santacoder" in model_id:
+    if "bigcode" in model_id:
         if sharded:
-            raise NotImplementedError("sharded is not supported for Santacoder")
+            if not FLASH_ATTENTION:
+                raise NotImplementedError(
+                    FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
+                )
+            return FlashSantacoderSharded(model_id, revision=revision)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
             return santacoder_cls(model_id, revision, quantize)
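Aside (not part of the commit): a hedged sketch of how the routing above behaves. The helper and the model id below are illustrative only; the real get_model matches a plain substring of the model id and returns the classes named in the hunk.

def santacoder_route(model_id: str, sharded: bool, flash_attention: bool) -> str:
    # Mirrors the branch in get_model: plain substring match on the model id.
    if "bigcode" not in model_id:
        return "handled elsewhere"
    if sharded:
        # Sharding is only implemented on top of the flash-attention kernels.
        return "FlashSantacoderSharded" if flash_attention else "NotImplementedError"
    return "FlashSantacoder" if flash_attention else "SantaCoder"

assert santacoder_route("bigcode/santacoder", sharded=True, flash_attention=True) == "FlashSantacoderSharded"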
server/text_generation_server/models/bloom.py
@@ -93,10 +93,11 @@ class BLOOMSharded(BLOOM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer, device=device, decode_buffer=1
@@ -108,6 +109,7 @@ class BLOOMSharded(BLOOM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -157,7 +159,7 @@ class BLOOMSharded(BLOOM):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:
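The same two-line pattern repeats in galactica.py, gpt_neox.py, opt.py and t5.py below: dtype is threaded into load_weights and each tensor is cast as it is copied in, instead of calling .to(dtype) on the fully loaded model. A minimal sketch of the idea, independent of this repository's classes (the helper name is made up for illustration, and it assumes dotted parameter names):

import torch
import torch.nn as nn

def load_casting(model: nn.Module, state_dict: dict, dtype: torch.dtype) -> None:
    # Casting one tensor at a time keeps peak memory near the size of a single
    # weight, instead of holding a full extra copy of the model in its original
    # precision and converting it afterwards with model.to(dtype).
    for name, tensor in state_dict.items():
        module_name, param_name = name.rsplit(".", 1)
        module = model.get_submodule(module_name)
        module._parameters[param_name] = nn.Parameter(
            tensor.contiguous().to(dtype), requires_grad=False
        )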
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -373,7 +373,7 @@ class LlamaMLP(nn.Module):
                     x,
                     approximate="tanh"
                     if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                    else None,
+                    else "none",
                 )
             )
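The change looks cosmetic but is a real fix: torch.nn.functional.gelu takes approximate as a string, so the non-tanh fallback must be the literal "none" rather than Python's None (the next file applies the same fix to the NeoX MLP). A quick check:

import torch
import torch.nn.functional as F

x = torch.randn(4)
F.gelu(x, approximate="none")  # exact gelu, the documented default
F.gelu(x, approximate="tanh")  # tanh approximation, matching gelu_fast / gelu_pytorch_tanh
# F.gelu(x, approximate=None)  # rejected: the argument must be the string "none" or "tanh"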
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -376,7 +376,12 @@ class FlashMLP(nn.Module):
         self.act = (
             ACT2FN[act]
             if "gelu" not in act
-            else lambda x: torch.nn.functional.gelu(x, approximate="tanh")
+            else lambda x: torch.nn.functional.gelu(
+                x,
+                approximate="tanh"
+                if act in ["gelu_fast", "gelu_pytorch_tanh"]
+                else "none",
+            )
         )

         if process_group is None:
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
 import torch
 import torch.distributed
+import torch.nn.functional as F

 from torch import nn
 from transformers.activations import ACT2FN
@@ -65,6 +67,127 @@ class FastLinear(nn.Linear):
         return torch.matmul(input, self.weight)


+class TensorParallelColumnLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        assert out_features % self.tp_world_size == 0
+        out_features = out_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+
+class TensorParallelRowLinear(FastLinear):
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        bias=True,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+        assert in_features % self.tp_world_size == 0
+        in_features = in_features // self.tp_world_size
+
+        super().__init__(
+            in_features=in_features,
+            out_features=out_features,
+            bias=bias,
+            device=device,
+            dtype=dtype,
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        out = super(TensorParallelRowLinear, self).forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+
+        return out
+
+
+class TensorParallelEmbedding(nn.Embedding):
+    def __init__(
+        self,
+        num_embeddings,
+        embedding_dim,
+        process_group: torch.distributed.ProcessGroup,
+        reduce=True,
+        padding_idx=None,
+        max_norm=None,
+        norm_type=2.0,
+        scale_grad_by_freq=False,
+        sparse=False,
+        _weight=None,
+        device=None,
+        dtype=None,
+    ):
+        self.process_group = process_group
+        self.tp_rank = process_group.rank()
+        self.tp_world_size = process_group.size()
+        self.reduce = reduce
+
+        self.original_num_embeddings = num_embeddings
+
+        assert num_embeddings % self.tp_world_size == 0
+        block_size = num_embeddings // self.tp_world_size
+        # inputs in `[min_id, max_id[` are handled by `self` to get embeddings
+        self.min_id = self.tp_rank * block_size
+        self.max_id = (self.tp_rank + 1) * block_size
+
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
+        super().__init__(
+            block_size,
+            embedding_dim,
+            padding_idx=padding_idx,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            _weight=_weight,
+            device=device,
+            dtype=dtype,
+        )
+
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
+        # translate for [0, self.max_id - self.min_id[
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
+        out = super().forward(input)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
+        return out
+
+
 class FlashMQAttention(torch.nn.Module):
     def __init__(
         self,
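The three classes added above are the usual tensor-parallel building blocks: a column-parallel linear shards its output features, a row-parallel linear shards its input features and all-reduces the partial sums, and the embedding gives each rank a contiguous block of rows plus one extra zero row (null_idx) so out-of-shard token ids contribute nothing before the reduction (here it is built with reduce=False, since the model sums wte and wpe first and all-reduces once). A single-process sketch of the embedding trick with toy sizes, not taken from the repository:

import torch
import torch.nn.functional as F

vocab, dim, world_size = 8, 4, 2
weight = torch.randn(vocab, dim)
ids = torch.tensor([0, 3, 5, 7])

block = vocab // world_size
out = torch.zeros(len(ids), dim)
for rank in range(world_size):
    # each rank keeps rows [rank*block, (rank+1)*block) plus one padded zero row
    shard = F.pad(weight[rank * block:(rank + 1) * block], (0, 0, 0, 1))
    local_ids = torch.where(
        (ids < rank * block) | (ids >= (rank + 1) * block),
        torch.tensor(block),        # out-of-shard ids hit the zero row
        ids - rank * block,
    )
    out += shard[local_ids]        # summing shards plays the role of all_reduce

assert torch.allclose(out, weight[ids])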
@@ -80,10 +203,16 @@ class FlashMQAttention(torch.nn.Module):
         self.softmax_scale = self.head_size ** (-0.5)

         if process_group is None:
-            self.attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
+            self.c_attn = FastLinear(hidden_size, hidden_size + 2 * self.head_size)
             self.c_proj = FastLinear(hidden_size, hidden_size)
         else:
-            raise NotImplementedError
+            self.num_heads = self.num_heads // process_group.size()
+            self.c_attn = FastLinear(hidden_size, self.head_size * (self.num_heads + 2))
+            self.c_proj = TensorParallelRowLinear(
+                hidden_size,
+                hidden_size,
+                process_group=process_group,
+            )

     def forward(
         self,
@@ -94,10 +223,12 @@ class FlashMQAttention(torch.nn.Module):
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        qkv = self.attn(hidden_states)
+        qkv = self.c_attn(hidden_states)

         # Split query from key_value
-        query, key_value = qkv.split([self.hidden_size, 2 * self.head_size], dim=1)
+        query, key_value = qkv.split(
+            [self.head_size * self.num_heads, 2 * self.head_size], dim=1
+        )

         # Prepare query and key_value for indexing
         query = query.view(-1, self.num_heads, self.head_size)
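Why the split sizes change: santacoder uses multi-query attention, so the fused c_attn projection produces the per-head queries plus a single key/value head. Once heads are sharded, the local query block is head_size times the local head count, while the 2 * head_size key/value block is replicated on every shard; hidden_size would over-count. A small shape check with made-up sizes:

head_size, num_heads, world_size = 64, 16, 2
hidden_size = head_size * num_heads          # 1024 before sharding

local_heads = num_heads // world_size        # 8 heads per shard
q_width = head_size * local_heads            # 512: sharded query projection
kv_width = 2 * head_size                     # 128: single K/V head, kept whole on every shard

# local c_attn output width per token, and the sizes recovered by qkv.split(...)
assert q_width + kv_width == hidden_size // world_size + 2 * head_size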
@@ -171,7 +302,7 @@ class MLP(nn.Module):
                     x,
                     approximate="tanh"
                     if act in ["gelu_fast", "gelu_pytorch_tanh"]
-                    else None,
+                    else "none",
                 )
             )
@@ -179,7 +310,16 @@ class MLP(nn.Module):
             self.c_fc = FastLinear(hidden_size, intermediate_size)
             self.c_proj = FastLinear(intermediate_size, hidden_size)
         else:
-            raise NotImplementedError
+            self.c_fc = TensorParallelColumnLinear(
+                hidden_size,
+                intermediate_size,
+                process_group=process_group,
+            )
+            self.c_proj = TensorParallelRowLinear(
+                intermediate_size,
+                hidden_size,
+                process_group=process_group,
+            )

     def forward(self, hidden_states):
         hidden_states = self.c_fc(hidden_states)
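The MLP uses the standard column-then-row split: c_fc shards its output features, the activation is applied locally, and c_proj shards its input features so that one all-reduce at the end recombines the result; no communication is needed between the two layers. A hedged single-process check of that claim (ReLU stands in for gelu, sizes are arbitrary):

import torch

torch.manual_seed(0)
x = torch.randn(2, 8)
c_fc = torch.nn.Linear(8, 16, bias=False)    # column-parallel: split output features
c_proj = torch.nn.Linear(16, 8, bias=False)  # row-parallel: split input features

reference = c_proj(torch.relu(c_fc(x)))

# "rank 0" and "rank 1" each own half of the intermediate dimension
h0 = torch.relu(x @ c_fc.weight[:8].T)
h1 = torch.relu(x @ c_fc.weight[8:].T)
y0 = h0 @ c_proj.weight[:, :8].T             # partial sum computed on rank 0
y1 = h1 @ c_proj.weight[:, 8:].T             # partial sum computed on rank 1

assert torch.allclose(y0 + y1, reference, atol=1e-6)  # the all_reduce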
@@ -246,11 +386,30 @@ class FlashSantacoderModel(nn.Module):
         super().__init__()
         self.config = config

+        self.process_group = process_group
+        self.tp_embeddings = False
         if process_group is not None:
-            raise NotImplementedError
+            self.tp_rank = process_group.rank()
+            self.tp_world_size = process_group.size()
+            if config.vocab_size % self.tp_world_size == 0:
+                self.tp_embeddings = True

-        self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+        if self.tp_embeddings:
+            self.wte = TensorParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                reduce=False,
+                process_group=process_group,
+            )
+            self.wpe = TensorParallelEmbedding(
+                config.max_position_embeddings,
+                config.hidden_size,
+                reduce=False,
+                process_group=process_group,
+            )
+        else:
+            self.wte = nn.Embedding(config.vocab_size, config.hidden_size)
+            self.wpe = nn.Embedding(config.max_position_embeddings, config.hidden_size)

         self.h = nn.ModuleList(
             [
@@ -273,9 +432,12 @@ class FlashSantacoderModel(nn.Module):
         self.num_heads = self.h[0].attn.num_heads

     def post_load_weights(self):
+        if self.tp_embeddings:
+            self.wte.add_null_idx()
+            self.wpe.add_null_idx()
         for layer in self.h:
             layer: Block
-            layer.attn.attn.transpose_weight()
+            layer.attn.c_attn.transpose_weight()
             layer.attn.c_proj.transpose_weight()
             layer.mlp.c_fc.transpose_weight()
             layer.mlp.c_proj.transpose_weight()
@@ -289,6 +451,8 @@ class FlashSantacoderModel(nn.Module):
         past_key_values=None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
+        if self.tp_embeddings:
+            torch.distributed.all_reduce(hidden_states, group=self.process_group)

         # Prefill
         if past_key_values is None:
@@ -335,7 +499,14 @@ class FlashSantacoderForCausalLM(nn.Module):
         self.transformer = FlashSantacoderModel(config, process_group)

-        self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
+        if self.transformer.tp_embeddings:
+            self.lm_head = FastLinear(
+                config.hidden_size,
+                config.vocab_size // process_group.size(),
+                bias=False,
+            )
+        else:
+            self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

     def post_load_weights(self):
         self.transformer.post_load_weights()
@@ -352,4 +523,18 @@ class FlashSantacoderForCausalLM(nn.Module):
         hidden_states, present = self.transformer(
             input_ids, position_ids, cu_seqlens, max_s, past_key_values
         )
-        return self.lm_head(hidden_states), present
+        logits = self.lm_head(hidden_states)
+
+        if self.transformer.tp_embeddings:
+            # Logits are sharded, so we need to gather them
+            world_logits = [
+                torch.empty_like(logits) for _ in range(self.transformer.tp_world_size)
+            ]
+            torch.distributed.all_gather(
+                world_logits, logits, group=self.transformer.process_group
+            )
+            world_logits = torch.cat(world_logits, dim=1)
+
+            return world_logits, present
+
+        return logits, present
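With tp_embeddings enabled, each rank's lm_head covers only vocab_size // world_size entries, so the final logits have to be gathered and concatenated along the vocabulary dimension. A single-process sketch of that reassembly (plain tensors standing in for the distributed all_gather):

import torch

tokens, hidden, vocab, world_size = 3, 6, 10, 2
hidden_states = torch.randn(tokens, hidden)
lm_head_weight = torch.randn(vocab, hidden)            # unsharded head, for reference

shards = lm_head_weight.chunk(world_size, dim=0)       # vocab // world_size rows per rank
local_logits = [hidden_states @ w.T for w in shards]   # what each rank computes
gathered = torch.cat(local_logits, dim=1)              # all_gather + cat(dim=1) above

assert torch.allclose(gathered, hidden_states @ lm_head_weight.T)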
server/text_generation_server/models/flash_neox.py
@@ -5,7 +5,7 @@ from accelerate import init_empty_weights
 from opentelemetry import trace
 from safetensors import safe_open
 from transformers import AutoTokenizer, AutoConfig
-from typing import Optional, Tuple, List
+from typing import Optional, List

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_neox_modeling import (
@@ -63,13 +63,13 @@ class FlashNeoXSharded(FlashNeoX):
         self.load_weights(
             model,
             filenames,
-            quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
         model.post_load_weights()
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,16 +80,14 @@ class FlashNeoXSharded(FlashNeoX):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
-            with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
-            ) as f:
+            with safe_open(file, framework="pt", device=str(device)) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
                     module = model.get_submodule(module_name)
@@ -142,7 +140,7 @@ class FlashNeoXSharded(FlashNeoX):
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                         )

-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)

                     if current_parameter_tensor is not None:
                         module._parameters[param_name] = tensor
server/text_generation_server/models/flash_santacoder.py
@@ -3,15 +3,20 @@ import torch.distributed
 from accelerate import init_empty_weights
 from opentelemetry import trace
+from safetensors import safe_open
 from pathlib import Path
-from transformers import AutoTokenizer, AutoConfig
+from transformers import AutoTokenizer, GPT2Config
 from typing import Optional, List

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
     FlashSantacoderForCausalLM,
+    TensorParallelRowLinear,
+    TensorParallelColumnLinear,
+    TensorParallelEmbedding,
 )
 from text_generation_server.utils import (
+    initialize_torch_distributed,
     weight_files,
     download_weights,
     weight_hub_files,
@@ -36,10 +41,9 @@ class FlashSantacoder(FlashCausalLM):
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )

-        config = AutoConfig.from_pretrained(
+        config = GPT2Config.from_pretrained(
             model_id,
             revision=revision,
-            trust_remote_code=True,  # Needed as the config is not part of Transformers
         )

         # We do not use from_pretrained as we modified the model internal module layout
@@ -54,12 +58,9 @@ class FlashSantacoder(FlashCausalLM):
         model = FlashSantacoderForCausalLM(config)

         self.load_weights(
-            model,
-            filenames,
-            device,
-            dtype,
+            model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
         )
-        self.model = model.eval().to(device).to(dtype)
+        self.model = model.eval()

         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer, device=device, decode_buffer=1
@@ -71,6 +72,7 @@ class FlashSantacoder(FlashCausalLM):
         filenames: List[Path],
         device: torch.device,
         dtype: torch.dtype,
+        transpose: bool,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
@@ -81,9 +83,9 @@ class FlashSantacoder(FlashCausalLM):
                 # Fused qkv
                 if "q_attn.weight" in key or "kv_attn.weight" in key:
-                    final_key = layer_name + ".attn.weight"
+                    final_key = layer_name + ".c_attn.weight"
                 elif "q_attn.bias" in key or "kv_attn.bias" in key:
-                    final_key = layer_name + ".attn.bias"
+                    final_key = layer_name + ".c_attn.bias"
                 else:
                     final_key = key
@@ -97,18 +99,19 @@ class FlashSantacoder(FlashCausalLM):
                     current_parameter_tensor = None

                 if current_parameter_tensor is not None:
-                    if (
+                    if transpose and (
                         "c_fc.weight" in key
                         or "c_proj.weight" in key
                         or "q_attn.weight" in key
                         or "kv_attn.weight" in key
+                        or "c_attn.weight" in key
                     ):
                         # Tranpose as we use nn.Linear instead of Conv1D
                         value = value.T

                     if current_parameter_tensor.device == torch.device("meta"):
                         # Init qkv
-                        if "attn.weight" in final_key:
+                        if "c_attn.weight" in final_key:
                             module._parameters[param_name] = value.new_empty(
                                 (
                                     model.transformer.head_size
@@ -116,7 +119,7 @@ class FlashSantacoder(FlashCausalLM):
                                     value.shape[1],
                                 )
                             )
-                        elif "attn.bias" in final_key:
+                        elif "c_attn.bias" in final_key:
                             module._parameters[param_name] = value.new_empty(
                                 (
                                     model.transformer.head_size
@@ -156,3 +159,208 @@ class FlashSantacoder(FlashCausalLM):
         return self.tokenizer.decode(
             generated_ids, skip_special_tokens=False, cleanup_tokenization_spaces=False
         )
+
+
+class FlashSantacoderSharded(FlashSantacoder):
+    def __init__(
+        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+    ):
+        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
+        self.master = self.rank == 0
+        if torch.cuda.is_available():
+            device = torch.device(f"cuda:{self.rank}")
+            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        else:
+            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
+
+        if quantize:
+            raise NotImplementedError(
+                "FlashSantacoderSharded does not support quantization"
+            )
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_id, revision=revision, padding_side="left", truncation_side="left"
+        )
+
+        config = GPT2Config.from_pretrained(
+            model_id,
+            revision=revision,
+        )
+
+        torch.distributed.barrier(group=self.process_group)
+        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
+
+        with init_empty_weights():
+            model = FlashSantacoderForCausalLM(config, self.process_group)
+
+        torch.distributed.barrier(group=self.process_group)
+        self.load_weights(
+            model,
+            filenames,
+            device=device,
+            dtype=dtype,
+            rank=self.rank,
+            world_size=self.world_size,
+            transpose=config.architectures[0].startswith("GPT2"),
+        )
+        self.model = model.eval()
+        torch.distributed.barrier(group=self.process_group)
+        super(FlashCausalLM, self).__init__(
+            tokenizer=tokenizer,
+            device=device,
+        )
+
+    @staticmethod
+    def load_weights(
+        model,
+        filenames: List[str],
+        device: torch.device,
+        dtype: torch.dtype,
+        rank: int,
+        world_size: int,
+        transpose: bool,
+    ):
+        for file in filenames:
+            with safe_open(file, framework="pt", device=str(device)) as f:
+                for key in f.keys():
+                    slice_ = f.get_slice(key)
+
+                    layer_name = ".".join(key.split(".")[:4])
+
+                    # Fused qkv
+                    if "q_attn.weight" in key or "kv_attn.weight" in key:
+                        final_key = layer_name + ".c_attn.weight"
+                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
+                        final_key = layer_name + ".c_attn.bias"
+                    else:
+                        final_key = key
+
+                    module_name, param_name = final_key.rsplit(".", 1)
+                    module = model.get_submodule(module_name)
+
+                    if isinstance(module, TensorParallelColumnLinear):
+                        dim = 1 if transpose and "weight" in param_name else 0
+                        size = slice_.get_shape()[dim]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = (
+                            slice_[start:stop] if dim == 0 else slice_[:, start:stop]
+                        )
+                    elif isinstance(module, TensorParallelRowLinear):
+                        if param_name == "weight":
+                            dim = 0 if transpose else 1
+                            size = slice_.get_shape()[dim]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = (
+                                slice_[start:stop]
+                                if dim == 0
+                                else slice_[:, start:stop]
+                            )
+                        else:
+                            tensor = slice_[:]
+                            # XXX: Hack for Rowlinear to add the bias only once.
+                            if rank != 0:
+                                tensor = torch.zeros_like(tensor)
+                    elif isinstance(module, TensorParallelEmbedding):
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    elif key == "lm_head.weight" and model.transformer.tp_embeddings:
+                        size = slice_.get_shape()[0]
+                        block_size = size // world_size
+                        start = rank * block_size
+                        stop = (rank + 1) * block_size
+                        tensor = slice_[start:stop]
+                    else:
+                        try:
+                            tensor = slice_[:]
+                        except:
+                            tensor = f.get_tensor(key)
+
+                    tensor = tensor.contiguous().to(dtype)
+
+                    try:
+                        current_parameter_tensor = module._parameters[param_name]
+                    except KeyError:
+                        current_parameter_tensor = None
+
+                    if current_parameter_tensor is not None:
+                        if transpose and (
+                            "c_fc.weight" in key
+                            or "c_proj.weight" in key
+                            or "q_attn.weight" in key
+                            or "kv_attn.weight" in key
+                            or "c_attn.weight" in key
+                        ):
+                            # Tranpose as we use nn.Linear instead of Conv1D
+                            tensor = tensor.T
+
+                        if current_parameter_tensor.device == torch.device("meta"):
+                            # Init qkv
+                            if "c_attn.weight" in final_key:
+                                module._parameters[param_name] = tensor.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2),
+                                        tensor.shape[1],
+                                    )
+                                )
+                            elif "c_attn.bias" in final_key:
+                                module._parameters[param_name] = tensor.new_empty(
+                                    (
+                                        model.transformer.head_size
+                                        * (model.transformer.num_heads + 2)
+                                    )
+                                )
+
+                        # Copy to correct slice
+                        if "q_attn" in key:
+                            size = tensor.shape[0]
+                            block_size = size // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            tensor = tensor[start:stop]
+                            module._parameters[param_name][: tensor.shape[0]] = tensor
+                        elif "kv_attn.weight" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
+                            ] = tensor
+                        elif "kv_attn.bias" in key:
+                            module._parameters[param_name][
+                                model.transformer.head_size
+                                * model.transformer.num_heads :
+                            ] = tensor
+                        elif "c_attn" in key:
+                            # Slice q_tensor by shard
+                            q_tensor = tensor[: -2 * model.transformer.head_size]
+                            block_size = q_tensor.shape[0] // world_size
+                            start = rank * block_size
+                            stop = (rank + 1) * block_size
+                            q_tensor = q_tensor[start:stop]
+                            module._parameters[param_name][
+                                : q_tensor.shape[0]
+                            ] = q_tensor
+                            # Kv tensor is copied for every shard
+                            kv_tensor = tensor[-2 * model.transformer.head_size :]
+                            module._parameters[param_name][
+                                q_tensor.shape[0] :
+                            ] = kv_tensor
+                        else:
+                            if current_parameter_tensor.shape != tensor.shape:
+                                raise ValueError(
+                                    f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
+                                )
+                            module._parameters[param_name] = tensor
+                    else:
+                        module._buffers[param_name] = tensor
+
+        torch.cuda.empty_cache()
+        model.post_load_weights()
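The loader above never materializes a full weight on any rank: it asks safetensors for a slice and keeps only the block [rank * block_size, (rank + 1) * block_size), cut along the output dimension for column-parallel layers and the input dimension for row-parallel ones (the dimensions swap when the checkpoint stores Conv1D-style transposed weights). The row-parallel bias is zeroed on every rank but 0 so the later all_reduce adds it exactly once, and the replicated key/value rows of the fused c_attn are copied whole. A hedged sketch of just the slicing arithmetic, on plain tensors:

import torch

def shard(tensor: torch.Tensor, rank: int, world_size: int, dim: int = 0) -> torch.Tensor:
    # keep this rank's contiguous block along `dim`
    size = tensor.shape[dim]
    assert size % world_size == 0
    block_size = size // world_size
    start = rank * block_size
    stop = (rank + 1) * block_size
    return tensor[start:stop] if dim == 0 else tensor[:, start:stop]

weight = torch.arange(12.0).reshape(6, 2)
# concatenating every rank's block reproduces the original tensor
assert torch.equal(torch.cat([shard(weight, r, 3) for r in range(3)], dim=0), weight)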
server/text_generation_server/models/galactica.py
@@ -219,10 +219,11 @@ class GalacticaSharded(Galactica):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -235,6 +236,7 @@ class GalacticaSharded(Galactica):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -285,7 +287,7 @@ class GalacticaSharded(Galactica):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:
server/text_generation_server/models/gpt_neox.py
@@ -64,10 +64,11 @@ class GPTNeoxSharded(CausalLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class GPTNeoxSharded(CausalLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -140,7 +142,7 @@ class GPTNeoxSharded(CausalLM):
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                         )

-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)

                     if quantize:
                         if not HAS_BITS_AND_BYTES:
server/text_generation_server/models/opt.py
@@ -80,10 +80,11 @@ class OPTSharded(OPT):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -96,6 +97,7 @@ class OPTSharded(OPT):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -146,7 +148,7 @@ class OPTSharded(OPT):
                         f"Name {name} -- Current {current_tensor.shape} and got {tensor.shape}"
                     )

-                tensor = tensor.contiguous()
+                tensor = tensor.contiguous().to(dtype)

                 if quantize:
                     if not HAS_BITS_AND_BYTES:
server/text_generation_server/models/t5.py
@@ -64,10 +64,11 @@ class T5Sharded(Seq2SeqLM):
             filenames,
             quantize=quantize,
             device=device,
+            dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval().to(dtype)
+        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(Seq2SeqLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +81,7 @@ class T5Sharded(Seq2SeqLM):
         filenames: List[str],
         quantize: bool,
         device: torch.device,
+        dtype: torch.dtype,
         rank: int,
         world_size: int,
     ):
@@ -146,7 +148,7 @@ class T5Sharded(Seq2SeqLM):
                             f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                         )

-                    tensor = tensor.contiguous()
+                    tensor = tensor.contiguous().to(dtype)

                     if quantize:
                         if not HAS_BITS_AND_BYTES: