Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
eedac9db
Unverified
Commit
eedac9db
authored
Aug 23, 2023
by
Wen Sun
Committed by
GitHub
Aug 22, 2023
Browse files
fix: revert code to avoid no attribute problem (#827)
parent
14f9c72b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
3 deletions
+4
-3
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+4
-3
No files found.
vllm/model_executor/models/gpt_bigcode.py
View file @
eedac9db
...
@@ -259,6 +259,8 @@ class GPTBigCodeForCausalLM(nn.Module):
...
@@ -259,6 +259,8 @@ class GPTBigCodeForCausalLM(nn.Module):
model_name_or_path
:
str
,
model_name_or_path
:
str
,
cache_dir
:
Optional
[
str
]
=
None
,
cache_dir
:
Optional
[
str
]
=
None
,
use_np_cache
:
bool
=
False
):
use_np_cache
:
bool
=
False
):
tensor_model_parallel_world_size
=
(
get_tensor_model_parallel_world_size
())
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
state_dict
=
self
.
state_dict
()
state_dict
=
self
.
state_dict
()
...
@@ -288,8 +290,7 @@ class GPTBigCodeForCausalLM(nn.Module):
...
@@ -288,8 +290,7 @@ class GPTBigCodeForCausalLM(nn.Module):
hidden_size
=
self
.
config
.
hidden_size
hidden_size
=
self
.
config
.
hidden_size
head_size
=
hidden_size
//
total_num_heads
head_size
=
hidden_size
//
total_num_heads
total_kv_size
=
head_size
*
total_num_kv_heads
total_kv_size
=
head_size
*
total_num_kv_heads
num_heads
=
(
total_num_heads
//
num_heads
=
total_num_heads
//
tensor_model_parallel_world_size
self
.
tensor_model_parallel_world_size
)
head_start
=
tensor_model_parallel_rank
*
num_heads
head_start
=
tensor_model_parallel_rank
*
num_heads
head_end
=
(
tensor_model_parallel_rank
+
1
)
*
num_heads
head_end
=
(
tensor_model_parallel_rank
+
1
)
*
num_heads
...
@@ -329,7 +330,7 @@ class GPTBigCodeForCausalLM(nn.Module):
...
@@ -329,7 +330,7 @@ class GPTBigCodeForCausalLM(nn.Module):
if
name
==
"transformer.wte.weight"
:
if
name
==
"transformer.wte.weight"
:
# Consider padding in the vocab size.
# Consider padding in the vocab size.
padded_vocab_size
=
param
.
shape
[
padded_vocab_size
=
param
.
shape
[
0
]
*
self
.
tensor_model_parallel_world_size
0
]
*
tensor_model_parallel_world_size
num_extra_rows
=
padded_vocab_size
-
self
.
config
.
vocab_size
num_extra_rows
=
padded_vocab_size
-
self
.
config
.
vocab_size
extra_rows
=
torch
.
empty
(
num_extra_rows
,
extra_rows
=
torch
.
empty
(
num_extra_rows
,
loaded_weight
.
shape
[
1
])
loaded_weight
.
shape
[
1
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment