Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a10d3056
Unverified
Commit
a10d3056
authored
Apr 11, 2024
by
Antoni Baum
Committed by
GitHub
Apr 11, 2024
Browse files
[Core] Set `linear_weights` directly on the layer (#3977)
parent
8afca508
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
114 additions
and
102 deletions
+114
-102
csrc/quantization/gptq/q_gemm.cu
csrc/quantization/gptq/q_gemm.cu
+1
-1
tests/kernels/test_moe.py
tests/kernels/test_moe.py
+1
-1
vllm/lora/layers.py
vllm/lora/layers.py
+4
-8
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+40
-37
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+16
-13
vllm/model_executor/layers/quantization/gptq.py
vllm/model_executor/layers/quantization/gptq.py
+26
-21
vllm/model_executor/layers/quantization/marlin.py
vllm/model_executor/layers/quantization/marlin.py
+13
-10
vllm/model_executor/layers/quantization/squeezellm.py
vllm/model_executor/layers/quantization/squeezellm.py
+13
-11
No files found.
csrc/quantization/gptq/q_gemm.cu
View file @
a10d3056
...
...
@@ -2067,7 +2067,7 @@ void gptq_shuffle
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
q_weight
));
vllm
::
gptq
::
shuffle_exllama_weight
(
(
uint32_t
*
)
q_weight
.
data_ptr
(),
q_perm
.
device
().
is_meta
()
?
NULL
:
(
int
*
)
q_perm
.
data_ptr
(),
q_perm
.
device
().
is_meta
()
||
q_perm
.
numel
()
==
0
?
NULL
:
(
int
*
)
q_perm
.
data_ptr
(),
q_weight
.
size
(
0
)
*
32
/
bit
,
q_weight
.
size
(
1
),
bit
...
...
tests/kernels/test_moe.py
View file @
a10d3056
...
...
@@ -73,7 +73,7 @@ def test_mixtral_moe(dtype: torch.dtype):
).
cuda
()
# Load the weights
vllm_moe
.
gate
.
linear_weights
[
"weight"
]
[:]
=
hf_moe
.
gate
.
weight
.
data
vllm_moe
.
gate
.
weight
.
data
[:]
=
hf_moe
.
gate
.
weight
.
data
for
i
in
range
(
config
.
num_local_experts
):
weights
=
(
hf_moe
.
experts
[
i
].
w1
.
weight
.
data
,
hf_moe
.
experts
[
i
].
w3
.
weight
.
data
)
...
...
vllm/lora/layers.py
View file @
a10d3056
...
...
@@ -368,7 +368,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
def
apply_weights
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
self
.
base_layer
.
linear_weights
,
x
,
bias
)
self
.
base_layer
,
x
,
bias
)
_apply_lora
(
x
,
self
.
lora_a_stacked
,
...
...
@@ -402,10 +402,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
if
self
.
base_layer
.
skip_bias_add
else
None
)
return
output
,
output_bias
@
property
def
linear_weights
(
self
):
return
self
.
base_layer
.
linear_weights
@
classmethod
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
...
...
@@ -505,7 +501,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def
apply_weights
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
self
.
base_layer
.
linear_weights
,
x
,
bias
)
self
.
base_layer
,
x
,
bias
)
_apply_lora_packed_nslice
(
x
,
self
.
lora_a_stacked
,
...
...
@@ -746,7 +742,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
def
apply_weights
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
self
.
base_layer
.
linear_weights
,
x
,
bias
)
self
.
base_layer
,
x
,
bias
)
_apply_lora_packed_nslice
(
x
,
self
.
lora_a_stacked
,
...
...
@@ -838,7 +834,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
def
apply_weights
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
self
.
base_layer
.
linear_weights
,
x
)
self
.
base_layer
,
x
)
_apply_lora
(
x
,
self
.
lora_a_stacked
,
...
...
vllm/model_executor/layers/linear.py
View file @
a10d3056
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
List
,
Optional
import
torch
import
torch.nn.functional
as
F
...
...
@@ -28,19 +28,24 @@ class LinearMethodBase(ABC):
"""Base class for different (maybe quantized) linear methods."""
@
abstractmethod
def
create_weights
(
self
,
input_size_per_partition
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
Any
]:
"""Create weights for a linear layer."""
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer."""
raise
NotImplementedError
@
abstractmethod
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
torch
.
Tensor
]
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""Apply the weights to the input tensor."""
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise
NotImplementedError
...
...
@@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
def
__init__
(
self
,
separate_bias_add
:
bool
=
False
):
self
.
separate_bias_add
=
separate_bias_add
def
create_weights
(
self
,
input_size_per_partition
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
Any
]
:
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
)
:
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
return
{
"weight"
:
weight
}
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
torch
.
Tensor
]
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
weight
=
weights
[
"
weight
"
]
weight
=
layer
.
weight
if
self
.
separate_bias_add
:
if
bias
is
not
None
:
return
F
.
linear
(
x
,
weight
)
+
bias
...
...
@@ -111,12 +118,9 @@ class ReplicatedLinear(torch.nn.Module):
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
self
.
linear_method
=
linear_method
self
.
linear_weights
=
self
.
linear_method
.
create_weights
(
self
.
input_size
,
self
.
output_size
,
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
for
name
,
weight
in
self
.
linear_weights
.
items
():
if
isinstance
(
weight
,
torch
.
Tensor
):
self
.
register_parameter
(
name
,
weight
)
self
.
linear_method
.
create_weights
(
self
,
self
.
input_size
,
self
.
output_size
,
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size
,
dtype
=
self
.
params_dtype
))
...
...
@@ -126,7 +130,7 @@ class ReplicatedLinear(torch.nn.Module):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
output
=
self
.
linear_method
.
apply_weights
(
self
.
linear_weights
,
x
,
bias
)
output
=
self
.
linear_method
.
apply_weights
(
self
,
x
,
bias
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
return
output
,
output_bias
...
...
@@ -177,13 +181,13 @@ class ColumnParallelLinear(torch.nn.Module):
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
self
.
linear_method
=
linear_method
self
.
linear_weights
=
self
.
linear_method
.
create_weights
(
self
.
input_size
,
self
.
output_size_per_partition
,
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
for
name
,
weight
in
self
.
l
in
ear_weights
.
items
():
if
isinstance
(
weight
,
torch
.
Tensor
):
self
.
register_parameter
(
name
,
weight
)
set_weight_attrs
(
weight
,
{
"
weight_loader
"
:
self
.
weight_loader
}
)
self
.
linear_method
.
create_weights
(
self
,
self
.
input_size
,
self
.
output_size
_per_partition
,
self
.
in
put_size
,
self
.
output_size
,
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
...
...
@@ -211,8 +215,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
output_parallel
=
self
.
linear_method
.
apply_weights
(
self
.
linear_weights
,
input_
,
bias
)
output_parallel
=
self
.
linear_method
.
apply_weights
(
self
,
input_
,
bias
)
if
self
.
gather_output
:
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
...
...
@@ -523,13 +526,13 @@ class RowParallelLinear(torch.nn.Module):
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
self
.
linear_method
=
linear_method
self
.
linear_weights
=
self
.
linear_method
.
create_weights
(
self
.
input_size_per_partition
,
self
.
output_size
,
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
for
name
,
weight
in
self
.
l
in
ear_weights
.
items
():
if
isinstance
(
weight
,
torch
.
Tensor
):
self
.
register_parameter
(
name
,
weight
)
set_weight_attrs
(
weight
,
{
"
weight_loader
"
:
self
.
weight_loader
}
)
self
.
linear_method
.
create_weights
(
self
,
self
.
input_size_per_partition
,
self
.
output_size
,
self
.
in
put_size
,
self
.
output_size
,
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
raise
ValueError
(
"When not reduce the results, adding bias to the "
...
...
@@ -569,7 +572,7 @@ class RowParallelLinear(torch.nn.Module):
# Matrix multiply.
output_parallel
=
self
.
linear_method
.
apply_weights
(
self
.
linear_weights
,
input_parallel
)
self
,
input_parallel
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
a10d3056
...
...
@@ -79,10 +79,11 @@ class AWQLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
AWQConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
input_size_per_partition
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
Any
]
:
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
)
:
if
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
...
...
@@ -136,19 +137,21 @@ class AWQLinearMethod(LinearMethodBase):
"input_dim"
:
0
,
"output_dim"
:
1
,
})
return
{
"qweight"
:
qweight
,
"qzeros"
:
qzeros
,
"scales"
:
scales
,
}
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
Any
]
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
weights
[
"
qweight
"
]
scales
=
weights
[
"
scales
"
]
qzeros
=
weights
[
"
qzeros
"
]
qweight
=
layer
.
qweight
scales
=
layer
.
scales
qzeros
=
layer
.
qzeros
pack_factor
=
self
.
quant_config
.
pack_factor
out_shape
=
(
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
]
*
pack_factor
,
))
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
...
...
@@ -163,5 +166,5 @@ class AWQLinearMethod(LinearMethodBase):
out
=
ops
.
awq_gemm
(
reshaped_x
,
qweight
,
scales
,
qzeros
,
pack_factor
)
if
bias
is
not
None
:
out
=
out
+
bias
out
.
add_
(
bias
)
return
out
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/gptq.py
View file @
a10d3056
...
...
@@ -89,12 +89,14 @@ class GPTQLinearMethod(LinearMethodBase):
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
)
->
Dict
[
str
,
Any
]:
**
extra_weight_attrs
,
):
del
output_size
# Unused.
if
input_size_per_partition
%
self
.
quant_config
.
group_size
!=
0
:
raise
ValueError
(
...
...
@@ -179,37 +181,40 @@ class GPTQLinearMethod(LinearMethodBase):
"input_dim"
:
scale_and_zero_input_dim
,
"output_dim"
:
1
,
})
return
{
"qweight"
:
qweight
,
"g_idx"
:
g_idx
,
"qzeros"
:
qzeros
,
"scales"
:
scales
,
"exllama_state"
:
exllama_state
,
}
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
set_weight_attrs
(
g_idx
,
extra_weight_attrs
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
exllama_state
=
exllama_state
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
Any
]
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
weights
[
"
qweight
"
]
qweight
=
layer
.
qweight
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if
weights
[
"
exllama_state
"
]
==
ExllamaState
.
UNINITIALIZED
:
if
layer
.
exllama_state
==
ExllamaState
.
UNINITIALIZED
:
if
self
.
quant_config
.
desc_act
:
weights
[
"g_idx"
]
=
torch
.
argsort
(
weights
[
"g_idx"
]).
to
(
torch
.
int
)
layer
.
g_idx
.
data
=
torch
.
argsort
(
layer
.
g_idx
).
to
(
torch
.
int
)
else
:
weights
[
"g_idx"
]
=
torch
.
empty
((
1
,
1
),
device
=
"meta"
)
weights
[
"exllama_state"
]
=
ExllamaState
.
READY
ops
.
gptq_shuffle
(
weights
[
"qweight"
],
weights
[
"g_idx"
],
layer
.
g_idx
.
data
=
torch
.
empty
((
0
,
),
device
=
layer
.
g_idx
.
device
)
layer
.
exllama_state
=
ExllamaState
.
READY
ops
.
gptq_shuffle
(
layer
.
qweight
,
layer
.
g_idx
,
self
.
quant_config
.
weight_bits
)
output
=
ops
.
gptq_gemm
(
reshaped_x
,
weights
[
"qweight"
],
weights
[
"qzeros"
],
weights
[
"scales"
],
weights
[
"g_idx"
],
weights
[
"exllama_state"
]
==
ExllamaState
.
READY
,
output
=
ops
.
gptq_gemm
(
reshaped_x
,
layer
.
qweight
,
layer
.
qzeros
,
layer
.
scales
,
layer
.
g_idx
,
layer
.
exllama_state
==
ExllamaState
.
READY
,
self
.
quant_config
.
weight_bits
)
if
bias
is
not
None
:
output
=
output
+
bias
output
.
add_
(
bias
)
return
output
.
reshape
(
out_shape
)
vllm/model_executor/layers/quantization/marlin.py
View file @
a10d3056
...
...
@@ -91,12 +91,14 @@ class MarlinLinearMethod(LinearMethodBase):
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
)
->
Dict
[
str
,
Any
]:
**
extra_weight_attrs
,
):
del
output_size
# Unused.
if
params_dtype
!=
torch
.
float16
:
...
...
@@ -187,21 +189,22 @@ class MarlinLinearMethod(LinearMethodBase):
dtype
=
torch
.
int
),
requires_grad
=
False
)
return
{
"B"
:
qweight
,
"s"
:
scales
,
"workspace"
:
workspace
,
}
layer
.
register_parameter
(
"B"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"s"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
register_parameter
(
"workspace"
,
workspace
)
set_weight_attrs
(
workspace
,
extra_weight_attrs
)
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
Any
]
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
qweight
=
weights
[
"B"
]
scales
=
weights
[
"s"
]
workspace
=
weights
[
"
workspace
"
]
qweight
=
layer
.
B
scales
=
layer
.
s
workspace
=
layer
.
workspace
x_2d
=
x
.
view
(
-
1
,
x
.
shape
[
-
1
])
...
...
vllm/model_executor/layers/quantization/squeezellm.py
View file @
a10d3056
...
...
@@ -68,10 +68,11 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
def
__init__
(
self
,
quant_config
:
SqueezeLLMConfig
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
input_size_per_partition
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_size_per_partition
:
int
,
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
)
->
Dict
[
str
,
Any
]
:
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
)
:
if
input_size_per_partition
%
self
.
quant_config
.
pack_factor
!=
0
:
raise
ValueError
(
"The input size is not aligned with the quantized "
...
...
@@ -103,17 +104,18 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
set_weight_attrs
(
lookup_table
,
{
"output_dim"
:
0
,
})
return
{
"qweight"
:
qweight
,
"lookup_table"
:
lookup_table
,
}
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"lookup_table"
,
lookup_table
)
set_weight_attrs
(
lookup_table
,
extra_weight_attrs
)
def
apply_weights
(
self
,
weights
:
Dict
[
str
,
Any
]
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
weights
[
"
qweight
"
]
lookup_table
=
weights
[
"
lookup_table
"
]
qweight
=
layer
.
qweight
lookup_table
=
layer
.
lookup_table
out_shape
=
x
.
shape
[:
-
1
]
+
(
qweight
.
shape
[
-
1
],
)
reshaped_x
=
x
.
reshape
(
-
1
,
x
.
shape
[
-
1
])
if
is_hip
():
...
...
@@ -126,5 +128,5 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
ops
.
squeezellm_gemm
(
reshaped_x
,
qweight
,
out
,
lookup_table
)
if
bias
is
not
None
:
out
=
out
+
bias
out
.
add_
(
bias
)
return
out
.
reshape
(
out_shape
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment