Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
aa9af07c
"vscode:/vscode.git/clone" did not exist on "c0911e31d367eca9600541bd634564c73753abbf"
Unverified
Commit
aa9af07c
authored
Oct 30, 2023
by
Woosuk Kwon
Committed by
GitHub
Oct 29, 2023
Browse files
Fix bias in InternLM (#1501)
parent
69be658b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
2 deletions
+4
-2
vllm/model_executor/models/internlm.py
vllm/model_executor/models/internlm.py
+4
-2
No files found.
vllm/model_executor/models/internlm.py
View file @
aa9af07c
...
@@ -62,6 +62,7 @@ class InternLMAttention(nn.Module):
...
@@ -62,6 +62,7 @@ class InternLMAttention(nn.Module):
self
,
self
,
hidden_size
:
int
,
hidden_size
:
int
,
num_heads
:
int
,
num_heads
:
int
,
bias
:
bool
,
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
max_position_embeddings
:
int
=
8192
,
max_position_embeddings
:
int
=
8192
,
):
):
...
@@ -81,13 +82,13 @@ class InternLMAttention(nn.Module):
...
@@ -81,13 +82,13 @@ class InternLMAttention(nn.Module):
self
.
qkv_proj
=
ColumnParallelLinear
(
self
.
qkv_proj
=
ColumnParallelLinear
(
hidden_size
,
hidden_size
,
3
*
self
.
total_num_heads
*
self
.
head_dim
,
3
*
self
.
total_num_heads
*
self
.
head_dim
,
bias
=
True
,
bias
=
bias
,
gather_output
=
False
,
gather_output
=
False
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
True
,
bias
=
bias
,
input_is_parallel
=
True
,
input_is_parallel
=
True
,
)
)
self
.
attn
=
PagedAttentionWithRoPE
(
self
.
attn
=
PagedAttentionWithRoPE
(
...
@@ -126,6 +127,7 @@ class InternLMDecoderLayer(nn.Module):
...
@@ -126,6 +127,7 @@ class InternLMDecoderLayer(nn.Module):
self
.
self_attn
=
InternLMAttention
(
self
.
self_attn
=
InternLMAttention
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
num_heads
=
config
.
num_attention_heads
,
num_heads
=
config
.
num_attention_heads
,
bias
=
config
.
bias
,
rope_theta
=
rope_theta
,
rope_theta
=
rope_theta
,
max_position_embeddings
=
max_position_embeddings
,
max_position_embeddings
=
max_position_embeddings
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment