OpenDAS / AutoAWQ / Commits

Commit 5d4ab5dc, authored Aug 17, 2023 by Casper Hansen
Parent: 290f45e7

Refactor entry.py [WIP]

Showing 3 changed files with 99 additions and 280 deletions (+99 -280):

    awq/entry.py               +33  -195
    awq/models/base.py         +62  -7
    awq/quantize/quantizer.py  +4   -78
awq/entry.py  (view file @ 5d4ab5dc)

from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import argparse
import os
import json
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
from accelerate.utils.modeling import get_balanced_memory
from awq.utils.parallel import auto_parallel
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from awq.utils.utils import simple_dispatch_model
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale

parser = argparse.ArgumentParser()
max_memory = [v.split(':') for v in (None or [])]
parser.add_argument('--model_path', type=str, help='path of the hf model')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument("--tasks", default=None, type=str)
parser.add_argument("--output_path", default=None, type=str)
parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true', help="enable model parallelism")
# max memory to offload larger models to CPU
parser.add_argument('--max_memory', type=str, nargs='*',
                    help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
                         + "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
                         + "mode details here: " \
                         + "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
parser.add_argument('--auto_parallel', action='store_true', help="automatically set parallel and batch_size")
# quantization config
parser.add_argument('--w_bit', type=int, default=None)
parser.add_argument('--q_group_size', type=int, default=-1)
parser.add_argument('--no_zero_point', action='store_true', help="disable zero_point")
parser.add_argument('--q_backend', type=str, default="fake", choices=["fake", "real"])
# save/load real quantized weights
parser.add_argument('--dump_quant', type=str, default=None, help='save quantized model')
parser.add_argument('--load_quant', type=str, default=None, help='load quantized model')
# apply/save/load awq
parser.add_argument('--run_awq', action='store_true', help="perform awq search process")
parser.add_argument('--dump_awq', type=str, default=None, help="save the awq search results")
parser.add_argument('--load_awq', type=str, default=None, help="load the awq search results")
args = parser.parse_args()

max_memory = [v.split(':') for v in (args.max_memory or [])]
max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}

if args.auto_parallel:
    gpu_list = auto_parallel(args)

# get quantization config (apart from w_bit)
q_config = {
    "zero_point": not args.no_zero_point,  # by default True
    "q_group_size": args.q_group_size,     # whether to use group quantization
}
print("Quantization config:", q_config)


def get_awq_model(model):
    from awq.models import MptAWQForCausalLM
    ...

@@ -73,149 +15,45 @@ def get_awq_model(model):

    else:
        raise NotImplementedError(type(model))


def build_model_and_enc(model_path):
def load_unquantized(model_path):
    if not os.path.exists(model_path):  # look into ssd
        raise FileNotFoundError(f"{model_path} not found!")
    print(f"* Building model {model_path}")

    # all hf model
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    if "mpt" in config.__class__.__name__.lower():
        tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
        enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
    else:
        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)

    if args.load_quant:
        # directly load quantized weights
        print("Loading pre-computed quantized weights...")
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16,
                                                     trust_remote_code=True)
        real_quantize_model_weight(model, w_bit=args.w_bit, q_config=q_config, init_only=True)
        model.tie_weights()

        # Infer device map
        kwargs = {"max_memory": max_memory} if len(max_memory) else {}
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=["OPTDecoderLayer", "LlamaDecoderLayer",
                                     "BloomBlock", "MPTBlock", "DecoderLayer"],
            **kwargs
        )
        # Load checkpoint in the model
        load_checkpoint_in_model(
            model,
            checkpoint=args.load_quant,
            device_map=device_map,
            offload_state_dict=True,
        )
        # Dispatch model
        model = simple_dispatch_model(model, device_map=device_map)

        model.eval()
    else:
        # fp16 to quantized
        args.run_awq &= not args.load_awq  # if load_awq, no need to run awq

        # Init model on CPU:
        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
        model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
                                                     trust_remote_code=True, **kwargs)
        model.eval()

        if args.run_awq:
    return model, tokenizer
            assert args.dump_awq, "Please save the awq results with --dump_awq"

            awq_model = get_awq_model(model)
            awq_results = awq_model.quantize(model, enc, args.w_bit, q_config)

            if args.dump_awq:
                dirpath = os.path.dirname(args.dump_awq)
                os.makedirs(dirpath, exist_ok=True)
                torch.save(awq_results, args.dump_awq)
                print("AWQ results saved at", args.dump_awq)

            exit(0)

        if args.load_awq:
def load_search_result_into_memory(model, search_path):
            print("Loading pre-computed AWQ results from", args.load_awq)
    awq_results = torch.load(search_path, map_location="cpu")
            awq_results = torch.load(args.load_awq, map_location="cpu")
            apply_scale(model, awq_results["scale"])
            apply_clip(model, awq_results["clip"])

        # weight quantization
def run_search(model, dump_path):
        if args.w_bit is not None:
    model, tokenizer = load_unquantized(model_path)
            if args.q_backend == "fake":
    awq_model = get_awq_model(model)
                assert args.dump_quant is None, \
    awq_results = awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config,
                                     run_search=True, run_quant=False)
                    "Need to use real quantization to dump quantized weights"
                pseudo_quantize_model_weight(model, w_bit=args.w_bit, q_config=q_config)
            elif args.q_backend == "real":
                # real quantization
                real_quantize_model_weight(model, w_bit=args.w_bit, q_config=q_config)
                if args.dump_quant:
                    dirpath = os.path.dirname(args.dump_quant)
                    os.makedirs(dirpath, exist_ok=True)
                    print(f"Saving the quantized model at {args.dump_quant}...")
                    torch.save(model.cpu().state_dict(), args.dump_quant)
                    exit(0)
            else:
                raise NotImplementedError

    # Move the model to GPU (as much as possible) for LM evaluation
    kwargs = {"max_memory": get_balanced_memory(model, max_memory if len(max_memory) > 0 else None)}
    device_map = infer_auto_device_map(
        model,
        # TODO: can we remove this?
        no_split_module_classes=["OPTDecoderLayer", "LlamaDecoderLayer",
                                 "BloomBlock", "MPTBlock", "DecoderLayer"],
        **kwargs
    )
    model = dispatch_model(model, device_map=device_map)

    return model, enc


def main():
    if args.output_path is not None and os.path.exists(args.output_path):
        # print(f"Results {args.output_path} already generated. Exit.")
        print(f"Results {args.output_path} already generated. Overwrite.")
        # exit()

    if args.dump_awq and os.path.exists(args.dump_awq):
        print(f"Found existing AWQ results {args.dump_awq}, exit.")
        exit()

    # a hack here to auto set model group
    model, enc = build_model_and_enc(args.model_path)

    if args.tasks is not None:
        task_names = args.tasks.split(",")
        lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
    dirpath = os.path.dirname(dump_path)
        results = evaluator.simple_evaluate(
    os.makedirs(dirpath, exist_ok=True)
            model=lm_eval_model,
    torch.save(awq_results, dump_path)
            tasks=task_names,
            batch_size=args.batch_size,
            no_cache=True,
            num_fewshot=args.num_fewshot,
        )
        print(evaluator.make_table(results))

def run_quant(model, search_path, dump_path):
    model, tokenizer = load_unquantized(model_path)
    load_search_result_into_memory(model, search_path)
        if args.output_path is not None:
    awq_model = get_awq_model(model)
            os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
    awq_model.quantize(model, w_bit=4, q_config=q_config, run_search=False, run_quant=True)
            # otherwise cannot save
            results["config"]["model"] = args.model_path
            with open(args.output_path, "w") as f:
                json.dump(results, f, indent=2)

    dirpath = os.path.dirname(dump_path)
    os.makedirs(dirpath, exist_ok=True)
    torch.save(model.cpu().state_dict(), dump_path)


if __name__ == '__main__':
    model_path = "./mpt-7b-8k-chat"
    main()
    search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"
    quant_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"
    q_config = {"zero_point": True, "q_group_size": 128}
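
Read together, the new pieces of entry.py sketch a two-step workflow: run the AWQ search once and dump the results, then load them and produce real quantized weights. The sketch below is not part of the commit; it assumes the helpers defined in the refactored entry.py above (load_unquantized, get_awq_model) are in scope and reuses the example paths from its __main__ block.

# Hypothetical walk-through of the refactored flow; load_unquantized and get_awq_model
# are the helpers defined in entry.py above and are assumed to be in scope.
import os
import torch
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale

model_path = "./mpt-7b-8k-chat"                                # example paths from the __main__ block
search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"
quant_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"
q_config = {"zero_point": True, "q_group_size": 128}

# Step 1 (run_search): search for AWQ scales/clips on the fp16 model and dump them.
model, tokenizer = load_unquantized(model_path)
awq_model = get_awq_model(model)
awq_results = awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config,
                                 run_search=True, run_quant=False)
os.makedirs(os.path.dirname(search_path), exist_ok=True)
torch.save(awq_results, search_path)

# Step 2 (run_quant): reload the fp16 model, apply the saved search results,
# quantize for real, and dump the state_dict.
model, tokenizer = load_unquantized(model_path)
awq_results = torch.load(search_path, map_location="cpu")
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
get_awq_model(model).quantize(model, w_bit=4, q_config=q_config,
                              run_search=False, run_quant=True)
torch.save(model.cpu().state_dict(), quant_path)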
awq/models/base.py  (view file @ 5d4ab5dc)

import gc
import tqdm
import torch
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict

from awq.utils.calib_data import get_calib_dataset
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation


class BaseAWQForCausalLM:

    @torch.no_grad()
    def quantize(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
    def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, calib_data="pileval"):
                 auto_scale=True, mse_range=True, run_search=False, run_quant=True,
                 calib_data="pileval", init_only=False):
        if run_search:
            self._awq_search(model, tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
                             auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

        if run_quant:
            self._awq_quant(model, w_bit, q_config, init_only)

    def _awq_quant(self, model, w_bit, q_config, init_only):
        assert q_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(model)

        # Run AWQ quantization
        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
            layer = layers[i]
            named_linears = get_named_linears(layer)

            if not isinstance(layer.ffn.act, ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_dict = self.get_act_for_scaling(layer)
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)

            for name, module in named_linears.items():
                if init_only:
                    q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], True)
                    q_linear.to(next(layer.parameters()).device)
                    set_op_by_name(layer, name, q_linear)
                else:
                    module.cuda()
                    module.weight.data, scales, zeros = pseudo_quantize_tensor(
                        module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config
                    )
                    scales = scales.t().contiguous()
                    zeros = zeros.t().contiguous()
                    q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], False, scales, zeros)
                    module.cpu()
                    q_linear.to(next(layer.parameters()).device)
                    set_op_by_name(layer, name, q_linear)
                    torch.cuda.empty_cache()
                    gc.collect()

            torch.cuda.empty_cache()
            gc.collect()

    def _awq_search(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
                    auto_scale=True, mse_range=True, calib_data="pileval"):
        layers = self.get_model_layers(model)
        samples = get_calib_dataset(
        ...

@@ -62,8 +117,8 @@ class BaseAWQForCausalLM:

            "clip": [],
        }

        # solve layer by layer
        # Run AWQ search layer by layer
        for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
        for i in tqdm(range(len(layers)), desc="AWQ Search:"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)
            ...
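
The substance of the change is this new quantize signature: instead of one monolithic pass, the base class now exposes an explicit search stage (run_search), a quantization stage (run_quant), and an init_only mode that creates empty WQLinear modules without quantizing any weights (for example before loading an already-quantized checkpoint). A minimal usage sketch follows; the import path and tokenizer handling are assumptions, not part of this commit.

# Sketch only, assuming MptAWQForCausalLM subclasses BaseAWQForCausalLM and is
# importable from awq.models as in the entry script above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from awq.models import MptAWQForCausalLM  # concrete subclass (assumed import path)

model_path = "./mpt-7b-8k-chat"           # example path from the entry script
q_config = {"zero_point": True, "q_group_size": 128}

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16,
                                             low_cpu_mem_usage=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

awq_model = MptAWQForCausalLM()

# Stage 1 (_awq_search): find per-channel scales and clipping ranges on calibration data.
awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config,
                   run_search=True, run_quant=False)

# Stage 2 (_awq_quant): swap each nn.Linear for a WQLinear built from the quantized weights.
# Passing init_only=True here would instead create empty WQLinear modules.
awq_model.quantize(model, w_bit=4, q_config=q_config, run_search=False, run_quant=True)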
awq/quantize/quantizer.py  (view file @ 5d4ab5dc)

import torch
import torch.nn as nn
from tqdm import tqdm
import gc
from .qmodule import ScaledActivation
from ..utils.module import set_op_by_name
from transformers.models.bloom.modeling_bloom import BloomBlock

EMBEDDING_KEYWORDS = ["embed"]
LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]


def scale_activations(module):
    param = next(module.parameters())
    dtype = param.dtype
    device = param.device
    if isinstance(module, BloomBlock):
        if isinstance(module.mlp.gelu_impl, ScaledActivation):
            return
        c = module.mlp.dense_h_to_4h.out_features
        act = ScaledActivation(
            module.mlp.gelu_impl,
            torch.ones(c, dtype=dtype, device=device)
        )
        set_op_by_name(module, "mlp.gelu_impl", act)
    elif 'mptblock' in str(module.__class__.__name__).lower():
        if isinstance(module.ffn.act, ScaledActivation):
            return
        # get activation scale
        scale_dict = MptAWQForCausalLM().get_act_for_scaling(module)
        scale_like = torch.ones(scale_dict['scale_shape'], dtype=dtype, device=device)
        # scale activation
        scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
        set_op_by_name(module, scale_dict['scale_name'], scaled_act)
    elif 'falcon' in str(module.__class__).lower():
        if isinstance(module.mlp.act, ScaledActivation):
            return
        c = module.mlp.dense_h_to_4h.out_features
        act = ScaledActivation(
            module.mlp.act,
            torch.ones(c, dtype=dtype, device=device)
        )
        set_op_by_name(module, "mlp.act", act)


# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
                           zero_point=True, q_group_size=-1,
                           ...

@@ -107,38 +63,8 @@ def pseudo_quantize_model_weight(

@torch.no_grad()
def real_quantize_model_weight(
        model, awq_model):
        model, w_bit, q_config,
    layers = awq_model.get_model_layers(model)
        init_only=False
    for i in tqdm(range(len(layers)), desc="real weight quantization..."):
):
    from .qmodule import WQLinear
    from .pre_quant import get_blocks, get_named_linears

    assert q_config["zero_point"], "We only support zero_point quantization now."
    layers = get_blocks(model)
    for i in tqdm(range(len(layers)),
                  desc="real weight quantization..." + ("(init only)" if init_only else "")):
        layer = layers[i]
        named_linears = get_named_linears(layer)
        del layer
        scale_activations(layer)
        for name, module in named_linears.items():
            if init_only:
                q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            else:
                module.cuda()
                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config
                )
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()
                q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], False, scales, zeros)
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

    torch.cuda.empty_cache()
    gc.collect()
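
The diff elides the body of pseudo_quantize_tensor, so for readers of this commit, here is a generic sketch of zero-point group quantization of a weight tensor. It illustrates the technique with the commit's w_bit=4, q_group_size=128 settings; it is not the repository's exact implementation.

# Generic sketch of zero-point (asymmetric) group quantization, in the style of
# pseudo_quantize_tensor; the hypothetical name marks it as an illustration only.
import torch

def pseudo_quantize_tensor_sketch(w: torch.Tensor, n_bit: int = 4, q_group_size: int = 128):
    org_shape = w.shape
    if q_group_size > 0:
        # quantize each group of q_group_size consecutive weights together
        w = w.reshape(-1, q_group_size)
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int = 2 ** n_bit - 1
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    zeros = (-min_val / scales).round()            # per-group zero point
    w_q = ((w / scales).round() + zeros).clamp(0, max_int)
    w_dq = (w_q - zeros) * scales                  # dequantize back to fp ("fake" quantization)
    return w_dq.reshape(org_shape), scales, zeros

# Example matching q_config = {"zero_point": True, "q_group_size": 128} at 4 bits.
w = torch.randn(4096, 4096)
w_dq, scales, zeros = pseudo_quantize_tensor_sketch(w, n_bit=4, q_group_size=128)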