OpenDAS / nni · Commit b2c31ca2 (unverified)

[Compression] Transformer pruning example (#5017)

Authored Aug 16, 2022 by J-shang, committed via GitHub on Aug 16, 2022. Parent commit: 3eca23d5.

Showing 12 changed files with 237 additions and 36 deletions (+237 -36).
Changed files:

    +2    -1    nni/algorithms/compression/v2/pytorch/pruning/tools/base.py
    +1    -1    nni/algorithms/compression/v2/pytorch/pruning/tools/metrics_calculator.py
    +29   -14   nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
    +8    -2    nni/algorithms/compression/v2/pytorch/utils/evaluator.py
    +0    -0    nni/algorithms/compression/v2/pytorch/utils/external/__init__.py
    +141  -0    nni/algorithms/compression/v2/pytorch/utils/external/huggingface.py
    +3    -2    nni/algorithms/compression/v2/pytorch/utils/scaling.py
    +13   -1    nni/common/graph_utils.py
    +26   -9    nni/compression/pytorch/speedup/compress_modules.py
    +13   -1    nni/compression/pytorch/utils/mask_conflict.py
    +0    -4    pipelines/full-test-compression.yml
    +1    -1    test/algo/compression/v2/test_scaling.py
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py

The `Literal` import moves to `typing_extensions`: `typing.Literal` only exists on Python 3.8+, so this keeps older interpreters working.

```diff
@@ -6,7 +6,8 @@ from datetime import datetime
 import logging
 from pathlib import Path
 import types
-from typing import List, Dict, Literal, Tuple, Optional, Callable, Union
+from typing import List, Dict, Tuple, Optional, Callable, Union
+from typing_extensions import Literal

 import json_tricks
 import torch
```
nni/algorithms/compression/v2/pytorch/pruning/tools/metrics_calculator.py

```diff
@@ -24,7 +24,7 @@ class StraightMetricsCalculator(MetricsCalculator):
         for module_name, targets_data in data.items():
             metrics[module_name] = {}
             for target_name, target_data in targets_data.items():
-                metrics[module_name][target_name] = target_data.clone().detach()
+                metrics[module_name][target_name] = self._get_scaler(module_name, target_name).shrink(target_data)
         return metrics
```
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py

`NormalSparsityAllocator` switches from a threshold comparison to a top-k scatter, and a new experimental `ThresholdSparsityAllocator` is added.

```diff
@@ -31,13 +31,28 @@ class NormalSparsityAllocator(SparsityAllocator):
             masks[module_name] = {}
             wrapper = self.pruner.get_modules_wrapper()[module_name]
             for target_name, target_metric in targets_metric.items():
                 sparsity_rate = wrapper.config['total_sparsity']
-                prune_num = int(sparsity_rate * target_metric.numel())
-                if prune_num != 0:
-                    threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
-                    shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
-                else:
-                    # target_metric should have the same size as shrinked_mask
-                    shrinked_mask = torch.ones_like(target_metric)
+                flatten_metric = target_metric.reshape(-1)
+                kept_num = flatten_metric.numel() - int(sparsity_rate * flatten_metric.numel())
+                kept_indices = torch.topk(flatten_metric, kept_num).indices
+                shrinked_mask = torch.zeros_like(flatten_metric).scatter(0, kept_indices, 1.0).reshape_as(target_metric)
                 masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
         return masks
+
+
+class ThresholdSparsityAllocator(SparsityAllocator):
+    """
+    Note: This allocator is an experimental allocator.
+    It takes 'total_sparsity' as a threshold and masks the pruning target where the metric is lower than the threshold.
+    """
+
+    def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
+        masks = {}
+        # TODO: Support more target type in wrapper & config list refactor
+        for module_name, targets_metric in metrics.items():
+            masks[module_name] = {}
+            wrapper = self.pruner.get_modules_wrapper()[module_name]
+            for target_name, target_metric in targets_metric.items():
+                threshold = wrapper.config['total_sparsity']
+                shrinked_mask = torch.gt(torch.sigmoid(target_metric), threshold).type_as(target_metric)
+                masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
+        return masks
```
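As an aside, a minimal sketch (with invented numbers) of why the top-k scatter is preferable to the old strictly-greater-than threshold: when metric values tie at the threshold, `torch.gt` can keep fewer elements than the sparsity budget allows, while scattering onto the top-k indices keeps exactly `kept_num` elements.

```python
import torch

# a metric with ties: a pure threshold rule cannot separate the three 0.5s
metric = torch.tensor([0.1, 0.5, 0.5, 0.5, 0.9, 1.0])
sparsity = 0.5                                               # prune half of the elements
kept_num = metric.numel() - int(sparsity * metric.numel())   # keep 3

# old style: the threshold is the largest of the 3 smallest values (0.5),
# and gt() then keeps only [0.9, 1.0] -> 2 kept instead of 3
threshold = torch.topk(metric, int(sparsity * metric.numel()), largest=False)[0].max()
old_mask = torch.gt(metric, threshold).type_as(metric)

# new style: scatter 1.0 onto the indices of the top kept_num values -> exactly 3 kept
kept_indices = torch.topk(metric, kept_num).indices
new_mask = torch.zeros_like(metric).scatter(0, kept_indices, 1.0)

print(old_mask.sum().item())  # 2.0: under-keeps because of ties
print(new_mask.sum().item())  # 3.0: matches the sparsity budget exactly
```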
```diff
@@ -115,10 +130,10 @@ class GlobalSparsityAllocator(SparsityAllocator):
             assert global_sparsity_rate == wrapper.config['total_sparsity']

         # find the largest metric value among all metrics
-        max_metric_value = list(list(metrics.values())[0].values())[0].max()
+        max_metric_value = list(list(metrics.values())[0].values())[0].max().item()
         for targets_metric in metrics.values():
             for target_metric in targets_metric.values():
-                max_metric_value = max_metric_value if max_metric_value >= target_metric.max() else target_metric.max()
+                max_metric_value = max_metric_value if max_metric_value >= target_metric.max().item() else target_metric.max().item()

         # prevent each module from being over-pruned; the protected ratio is 'max_sparsity_per_layer'
         for module_name, targets_metric in metrics.items():
```
```diff
@@ -127,10 +142,10 @@ class GlobalSparsityAllocator(SparsityAllocator):
                 max_sparsity = wrapper.config.get('max_sparsity_per_layer', {}).get(module_name, 0.99)
                 assert 0 <= max_sparsity <= 1
                 old_target_mask: Tensor = getattr(wrapper, f'{target_name}_mask')
-                expand_times = old_target_mask.numel() // target_metric.numel()
-                max_pruning_numel = int(max_sparsity * target_metric.numel()) * expand_times
-                threshold = torch.topk(target_metric.reshape(-1), max_pruning_numel, largest=False)[0].max()
-                metrics[module_name][target_name] = torch.where(target_metric <= threshold, target_metric, max_metric_value)
+                flatten_metric = target_metric.reshape(-1)
+                protected_pruning_numel = target_metric.numel() - int(max_sparsity * target_metric.numel())
+                protected_indices = torch.topk(flatten_metric, protected_pruning_numel).indices
+                metrics[module_name][target_name] = flatten_metric.scatter(0, protected_indices, max_metric_value).reshape_as(target_metric)

         # build the global_metric & calculate global threshold
         metric_list = []
```
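A toy illustration of the protection step, with invented numbers: the top elements of each layer's metric are overwritten with the global maximum (a plain Python float after `.item()`, not a 0-dim tensor), so they can never fall below the global threshold and the layer cannot be pruned past `max_sparsity_per_layer`.

```python
import torch

metric = torch.tensor([0.2, 0.8, 0.1, 0.9])
max_metric_value = metric.max().item()        # .item(): plain float, not a 0-dim tensor
max_sparsity = 0.5                            # at most half of this layer may be pruned

flatten_metric = metric.reshape(-1)
protected_numel = metric.numel() - int(max_sparsity * metric.numel())   # 2 protected
protected_indices = torch.topk(flatten_metric, protected_numel).indices
protected = flatten_metric.scatter(0, protected_indices, max_metric_value)
print(protected)   # tensor([0.2000, 0.9000, 0.1000, 0.9000])
```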
The `DependencyAwareAllocator` change only drops a redundant line-continuation backslash inside the dict comprehension:

```diff
@@ -207,7 +222,7 @@ class DependencyAwareAllocator(SparsityAllocator):
             fused_metrics = self._metric_fuse(sub_metrics)

             for target_name, fused_metric in fused_metrics.items():
-                sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity'] \
+                sparsity_rates = {module_name: self.pruner.get_modules_wrapper()[module_name].config['total_sparsity']
                                   for module_name in sub_metrics.keys()}
                 min_sparsity_rate = min(sparsity_rates.values())
```
nni/algorithms/compression/v2/pytorch/utils/evaluator.py

The `pytorch_lightning` import becomes optional, guarded by a `LightingInstalled` flag (sic, the identifier is spelled this way in the source), and `LightningEvaluator` asserts it at construction time.

```diff
@@ -14,8 +14,13 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.hooks import RemovableHandle

-import pytorch_lightning as pl
-from pytorch_lightning.callbacks import Callback
+try:
+    import pytorch_lightning as pl
+    from pytorch_lightning.callbacks import Callback
+except ImportError:
+    LightingInstalled = False
+else:
+    LightingInstalled = True

 from nni.common import is_traceable
 from .constructor_helper import OptimizerConstructHelper, LRSchedulerConstructHelper
```

```diff
@@ -292,6 +297,7 @@ class LightningEvaluator(Evaluator):
     def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
                  dummy_input: Any | None = None):
+        assert LightingInstalled, 'pytorch_lightning is not installed.'
         err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
         err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
         assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
```
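For context, a minimal sketch of the traced-trainer convention that the new assert and error message refer to (the data module class is a stand-in; argument values are illustrative):

```python
import nni
import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):
    """Stand-in for a real user-defined data module; only here to make the sketch complete."""

# LightningEvaluator only accepts traced objects so that NNI can record the
# init arguments and re-create them later; nni.trace wraps the class accordingly.
trainer = nni.trace(pl.Trainer)(max_epochs=1)
data_module = nni.trace(MyDataModule)()
```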
nni/algorithms/compression/v2/pytorch/utils/external/__init__.py (new file, mode 100644, empty)
nni/algorithms/compression/v2/pytorch/utils/external/huggingface.py (new file, mode 100644)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import annotations

import logging
import re
from typing import Tuple

from torch.nn import Module

try:
    from transformers import (
        PreTrainedModel,
        BartConfig,
        BertConfig,
        T5Config
    )
except ImportError:
    TRANSFORMERS_INSTALLED = False
else:
    TRANSFORMERS_INSTALLED = True

from nni.algorithms.compression.v2.pytorch.utils.attr import get_nested_attr

_logger = logging.getLogger(__name__)


# huggingface transformers pretrained model parsers supported: bart, bert, t5
def parser_factory(model: Module) -> HuggingfaceModelParser | None:
    if TRANSFORMERS_INSTALLED and isinstance(model, PreTrainedModel):
        cls2parser = {
            BartConfig: HuggingfaceBartParser,
            BertConfig: HuggingfaceBertParser,
            T5Config: HuggingfaceT5Parser
        }
        type2parser = {
            'bart': HuggingfaceBartParser,
            'bert': HuggingfaceBertParser,
            't5': HuggingfaceT5Parser
        }

        if hasattr(model, 'config_class'):
            parser = cls2parser.get(getattr(model, 'config_class'))
        elif hasattr(model, 'model_type'):
            parser = type2parser.get(getattr(model, 'model_type'))
        else:
            parser = None

        return parser
    else:
        return None


class HuggingfaceModelParser:
    # This class is used to verify that a module name belongs to a specific huggingface transformers pretrained model.
    # Further, it verifies that the module with this name is some kind of special layer (QKVO or FFN).
    TRANSFORMER_PREFIX: str
    QKV: Tuple[str, ...]
    QKVO: Tuple[str, ...]
    FFN1: Tuple[str, ...]
    FFN2: Tuple[str, ...]
    ATTENTION: Tuple[str, ...]

    @classmethod
    def is_huggingface_model(cls, model: Module):
        return model.__module__.split('.')[0] == 'transformers'

    @classmethod
    def is_attention(cls, module_name: str, include_output: bool = True) -> bool:
        patterns = cls.QKVO if include_output else cls.QKV
        for pattern in patterns:
            if pattern in module_name:
                return True
        return False

    @classmethod
    def is_ffn(cls, module_name: str, ffn_num: int = 1) -> bool:
        if cls.is_attention(module_name):
            return False
        if ffn_num == 1:
            for pattern in cls.FFN1:
                if pattern in module_name:
                    return True
        if ffn_num == 2:
            for pattern in cls.FFN2:
                if pattern in module_name:
                    return True
        return False

    @classmethod
    def get_num_heads(cls, module_name: str, model: Module) -> int:
        if cls.is_attention(module_name, include_output=True):
            for pattern in cls.ATTENTION:
                match = re.search(pattern, module_name)
                if match:
                    attention_module_name = module_name[0:match.span()[1]]
                    module = get_nested_attr(model, attention_module_name)
                    if hasattr(module, 'num_attention_heads'):
                        num_heads = module.num_attention_heads
                    elif hasattr(module, 'num_heads'):
                        num_heads = module.num_heads
                    elif hasattr(module, 'n_heads'):
                        num_heads = module.n_heads
                    else:
                        warn_msg = f'Can not get the number of heads of attention layer: {attention_module_name}.'
                        _logger.warning(warn_msg)
                        num_heads = 0
                    return num_heads
            return 0
        else:
            warn_msg = f'The layer `{module_name}` might not be an (Q|K|V) attention layer.'
            _logger.warning(warn_msg)
            return 0


class HuggingfaceBertParser(HuggingfaceModelParser):
    TRANSFORMER_PREFIX = r'bert\.encoder\.layer\.[0-9]+\.'
    QKV = ('attention.self.query', 'attention.self.key', 'attention.self.value')
    QKVO = QKV + ('attention.output.dense',)
    FFN1 = ('intermediate.dense',)
    FFN2 = ('output.dense',)
    ATTENTION = ('attention.self',)


class HuggingfaceBartParser(HuggingfaceModelParser):
    TRANSFORMER_PREFIX = r'(en|de)coder\.layer\.[0-9]+\.'
    QKV = ('self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj',
           'encoder_attn.q_proj', 'encoder_attn.k_proj', 'encoder_attn.v_proj')
    QKVO = QKV + ('self_attn.out_proj', 'encoder_attn.out_proj')
    FFN1 = ('fc1',)
    FFN2 = ('fc2',)
    ATTENTION = ('self_attn', 'encoder_attn')


class HuggingfaceT5Parser(HuggingfaceModelParser):
    TRANSFORMER_PREFIX = r'(en|de)coder\.block\.[0-9]+\.layer\.[0-9]+.'
    QKV = ('SelfAttention.q', 'SelfAttention.k', 'SelfAttention.v',
           'EncDecAttention.q', 'EncDecAttention.k', 'EncDecAttention.v')
    QKVO = QKV + ('SelfAttention.o', 'EncDecAttention.o')
    FFN1 = ('DenseReluDense.wi',)
    FFN2 = ('DenseReluDense.wo',)
    ATTENTION = ('SelfAttention', 'EncDecAttention')
```
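A hedged usage sketch of the new parser (assumes `transformers` is installed and the checkpoint can be downloaded; the module names are illustrative):

```python
from transformers import BertModel
from nni.algorithms.compression.v2.pytorch.utils.external.huggingface import parser_factory

model = BertModel.from_pretrained('bert-base-uncased')
parser = parser_factory(model)      # resolves to HuggingfaceBertParser via config_class

name = 'encoder.layer.0.attention.self.query'
print(parser.is_attention(name))                             # True: matches a QKV pattern
print(parser.is_ffn('encoder.layer.0.intermediate.dense'))   # True: matches FFN1
```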
nni/algorithms/compression/v2/pytorch/utils/scaling.py

```diff
@@ -122,8 +122,9 @@ class Scaling:
         permute_dims = [2 * _ for _ in range(len(kernel_size))] + [2 * _ + 1 for _ in range(len(kernel_size))]
         converted_target = target.reshape(reshape_size).permute(permute_dims).reshape(final_size + [-1])
-        # step 2: reduce the converted_target last dim in a certain way, by default converted_target.sum(-1).
-        result = reduce_func(converted_target) if reduce_func else converted_target.sum(-1)
+        # step 2: reduce the converted_target last dim in a certain way, by default converted_target.mean(-1).
+        # `sum` does not take into account the metric scale problem, it is better to use `mean` here.
+        result = reduce_func(converted_target) if reduce_func else converted_target.mean(-1)
         # step 3: reduce the dims where kernel_size is -1.
         # e.g., target size is [10, 40], kernel_size is [-1, 4], result size is [1, 10], then reduce result to size [10].
```
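A small sketch of the new default reduction, using the same data as the updated test below. The explicit float dtype matters because `mean()` (unlike `sum()`) is not defined for integer tensors, which is most likely why the test gains `dtype=torch.float32`:

```python
import torch
from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

# same data as the updated test, but built with arange for brevity
data = torch.arange(100, dtype=torch.float32).reshape(10, 10)
scaler = Scaling([5], kernel_padding_mode='front')
shrinked = scaler.shrink(data)   # each kernel block is now reduced by mean, not sum
```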
nni/common/graph_utils.py

LightningModule instances are now traced through their own `to_torchscript` helper instead of `torch.jit.trace`:

```diff
@@ -75,6 +75,18 @@ class TorchGraph:
         if torch.__version__ >= '1.6.0':
             # only pytorch with version greater than 1.6.0 has the strict option
             kw_args['strict'] = False
-        self.trace = torch.jit.trace(model, dummy_input, **kw_args)
+        try:
+            import pytorch_lightning as pl
+        except ImportError:
+            is_lightning_module = False
+        else:
+            if isinstance(model, pl.LightningModule):
+                is_lightning_module = True
+            else:
+                is_lightning_module = False
+        if is_lightning_module:
+            self.trace = model.to_torchscript(method="trace", example_inputs=dummy_input, **kw_args)
+        else:
+            self.trace = torch.jit.trace(model, dummy_input, **kw_args)
         torch._C._jit_pass_inline(self.trace.graph)
         model.train(training)
```
nni/compression/pytorch/speedup/compress_modules.py

```diff
@@ -31,6 +31,7 @@ replace_module = {
    'SELU': lambda module, masks: no_replace(module, masks),
    'CELU': lambda module, masks: no_replace(module, masks),
    'GELU': lambda module, masks: no_replace(module, masks),
+   'GELUActivation': lambda module, masks: no_replace(module, masks),
    'Sigmoid': lambda module, masks: no_replace(module, masks),
    'SiLU': lambda module, masks: no_replace(module, masks),
    'Mish': lambda module, masks: no_replace(module, masks),
```
```diff
@@ -74,6 +75,7 @@ def convert_to_coarse_mask(t_mask, dim):
     n_dims = len(shape)
     dim_list = list(range(n_dims))
     # try to reduce the mask from the dim-th dimension
+    dim = dim if dim >= 0 else n_dims + dim
     dim_list.remove(dim)

     t_merged = torch.sum(t_mask, dim_list)
```
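A self-contained sketch (invented shapes) of what the added line enables: `dim` may now be negative, which the next hunk relies on. Without the normalization, `dim_list.remove(-1)` would raise `ValueError` because `dim_list` only contains non-negative indices.

```python
import torch

# convert_to_coarse_mask-style reduction: sum the fine-grained mask over
# every dimension except `dim`, where `dim` may now be negative.
t_mask = torch.ones(2, 3, 4)
t_mask[:, :, 3] = 0                       # channel 3 of the last dim fully masked
dim = -1
n_dims = len(t_mask.shape)
dim = dim if dim >= 0 else n_dims + dim   # the added normalization line
dim_list = list(range(n_dims))
dim_list.remove(dim)
t_merged = torch.sum(t_mask, dim_list)
print(t_merged)                           # tensor([6., 6., 6., 0.])
```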
```diff
@@ -190,12 +192,9 @@ def replace_linear(linear, masks):
     in_mask = in_masks[0]
     weight_mask = weight_mask['weight']
-    # the input of the linear may have two dimensions (CV models) or three
-    # dimensions (Bert, for example)
-    n_dim = len(in_mask.size())
     # N C K
-    pruned_in, remained_in = convert_to_coarse_mask(in_mask, n_dim - 1)
-    pruned_out, remained_out = convert_to_coarse_mask(output_mask, n_dim - 1)
+    pruned_in, remained_in = convert_to_coarse_mask(in_mask, -1)
+    pruned_out, remained_out = convert_to_coarse_mask(output_mask, -1)
     n_remained_in = weight_mask.size(1) - pruned_in.size(0)
     n_remained_out = weight_mask.size(0) - pruned_out.size(0)
     remained_in, remained_out = remained_in.to(
```
```diff
@@ -610,11 +609,29 @@ def replace_layernorm(layernorm, masks):
     if len(in_masks) != 1:
         raise InputsNumberError()
     in_mask = in_masks[0]
-    dense_shape = convert_dense_shape(in_mask)
-    norm_shape = layernorm.normalized_shape
-    dim_n = len(dense_shape) - len(norm_shape)
-    return nn.LayerNorm(dense_shape[dim_n:], layernorm.eps, layernorm.elementwise_affine)
+    old_normalized_shape = layernorm.normalized_shape
+    new_normalized_shape = []
+    remained_list = []
+    for i in range(-len(old_normalized_shape), 0):
+        pruned, remained = convert_to_coarse_mask(in_mask, i)
+        new_normalized_shape.append(old_normalized_shape[i] - pruned.size()[0])
+        remained_list.append(remained)
+    new_layernorm = nn.LayerNorm(tuple(new_normalized_shape), layernorm.eps, layernorm.elementwise_affine)
+    if new_layernorm.elementwise_affine:
+        new_layernorm.to(layernorm.weight.device)
+        # NOTE: should we keep the weight & bias?
+        with torch.no_grad():
+            tmp_weight_data = layernorm.weight.data
+            tmp_bias_data = layernorm.bias.data
+            for i, remained in enumerate(remained_list):
+                tmp_weight_data = torch.index_select(tmp_weight_data, i, remained)
+                tmp_bias_data = torch.index_select(tmp_bias_data, i, remained)
+            new_layernorm.weight.data = tmp_weight_data
+            new_layernorm.bias.data = tmp_bias_data
+    return new_layernorm


 def replace_embedding(embedding, masks):
     """
```
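A self-contained sketch of the idea behind the new `replace_layernorm`: rather than only slicing the normalized shape, the surviving affine parameters are gathered along each normalized dimension with `index_select`. The shapes and indices here are invented:

```python
import torch
import torch.nn as nn

old_ln = nn.LayerNorm((8,))                    # pretend 2 of 8 channels were pruned
remained = torch.tensor([0, 1, 2, 3, 4, 5])    # indices that survive the mask

new_ln = nn.LayerNorm((remained.numel(),), old_ln.eps, old_ln.elementwise_affine)
with torch.no_grad():
    # carry the surviving affine parameters over to the smaller layer
    new_ln.weight.data = torch.index_select(old_ln.weight.data, 0, remained)
    new_ln.bias.data = torch.index_select(old_ln.bias.data, 0, remained)

out = new_ln(torch.randn(4, 6))                # works on the pruned hidden size
print(out.shape)                               # torch.Size([4, 6])
```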
nni/compression/pytorch/utils/mask_conflict.py

This mirrors the tracing change in nni/common/graph_utils.py above:

```diff
@@ -45,6 +45,18 @@ def fix_mask_conflict(masks, model, dummy_input, traced=None):
         if torch.__version__ >= '1.6.0':
             # only pytorch with version greater than 1.6.0 has the strict option
             kw_args['strict'] = False
-        traced = torch.jit.trace(model, dummy_input, **kw_args)
+        try:
+            import pytorch_lightning as pl
+        except ImportError:
+            is_lightning_module = False
+        else:
+            if isinstance(model, pl.LightningModule):
+                is_lightning_module = True
+            else:
+                is_lightning_module = False
+        if is_lightning_module:
+            traced = model.to_torchscript(method="trace", example_inputs=dummy_input, **kw_args)
+        else:
+            traced = torch.jit.trace(model, dummy_input, **kw_args)
         model.train(training)
```
pipelines/full-test-compression.yml

```diff
@@ -42,10 +42,6 @@ stages:
       platform: ubuntu-latest-gpu
       python_env: venv

-    - script: |
-        python -m pip install "pytorch-lightning<1.7"
-      displayName: Pin PytorchLightning version
-
     - template: templates/install-nni.yml
     - template: templates/download-test-data.yml
```
test/algo/compression/v2/test_scaling.py

The explicit `dtype=torch.float32` is needed now that `Scaling.shrink` reduces with `mean()` by default, which is not defined for integer tensors:

```diff
@@ -8,7 +8,7 @@ from nni.algorithms.compression.v2.pytorch.utils.scaling import Scaling

 def test_scaling():
-    data = torch.tensor([_ for _ in range(100)]).reshape(10, 10)
+    data = torch.tensor([_ for _ in range(100)], dtype=torch.float32).reshape(10, 10)
     scaler = Scaling([5], kernel_padding_mode='front')
     shrinked_data = scaler.shrink(data)
```