renzhc / diffusers_dcu · Commits · 8520d496

Unverified commit 8520d496, authored May 05, 2025 by Connector Switch, committed by GitHub on May 05, 2025.
[Feature] Implement tiled VAE encoding/decoding for Wan model. (#11414)
* implement tiled encode/decode
* address review comments
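As a quick orientation, here is a minimal usage sketch of the switches this commit adds. `AutoencoderKLWan`, `enable_tiling`, `enable_slicing`, `encode`, and `decode` come from this diff; the checkpoint id, dtype, and input shape are illustrative assumptions, not part of the commit.

    import torch
    from diffusers import AutoencoderKLWan

    # Illustrative checkpoint id; substitute whatever Wan VAE weights you actually use.
    vae = AutoencoderKLWan.from_pretrained(
        "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
    )

    vae.enable_tiling()   # spatial tiling: 256x256 tiles with a 192-px stride by default
    vae.enable_slicing()  # process one video latent at a time across the batch dimension

    video = torch.randn(1, 3, 9, 512, 512)  # (batch, channels, frames, height, width)
    with torch.no_grad():
        latents = vae.encode(video).latent_dist.sample()
        decoded = vae.decode(latents).sample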
Parent: a674914f
Showing 2 changed files with 320 additions and 14 deletions (+320 / -14)
src/diffusers/models/autoencoders/autoencoder_kl_wan.py   (+243 / -13)
tests/models/autoencoders/test_models_autoencoder_wan.py  (+77 / -1)
src/diffusers/models/autoencoders/autoencoder_kl_wan.py (view file @ 8520d496)
@@ -730,6 +730,76 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
         )
 
+        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
+
+        # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
+        # to perform decoding of a single video latent at a time.
+        self.use_slicing = False
+
+        # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
+        # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
+        # intermediate tiles together, the memory requirement can be lowered.
+        self.use_tiling = False
+
+        # The minimal tile height and width for spatial tiling to be used
+        self.tile_sample_min_height = 256
+        self.tile_sample_min_width = 256
+
+        # The minimal distance between two spatial tiles
+        self.tile_sample_stride_height = 192
+        self.tile_sample_stride_width = 192
+
+    def enable_tiling(
+        self,
+        tile_sample_min_height: Optional[int] = None,
+        tile_sample_min_width: Optional[int] = None,
+        tile_sample_stride_height: Optional[float] = None,
+        tile_sample_stride_width: Optional[float] = None,
+    ) -> None:
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+        allow processing larger images.
+
+        Args:
+            tile_sample_min_height (`int`, *optional*):
+                The minimum height required for a sample to be separated into tiles across the height dimension.
+            tile_sample_min_width (`int`, *optional*):
+                The minimum width required for a sample to be separated into tiles across the width dimension.
+            tile_sample_stride_height (`int`, *optional*):
+                The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there
+                are no tiling artifacts produced across the height dimension.
+            tile_sample_stride_width (`int`, *optional*):
+                The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
+                artifacts produced across the width dimension.
+        """
+        self.use_tiling = True
+        self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
+        self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
+        self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
+        self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+
+    def disable_tiling(self) -> None:
+        r"""
+        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_tiling = False
+
+    def enable_slicing(self) -> None:
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.use_slicing = True
+
+    def disable_slicing(self) -> None:
+        r"""
+        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
+        decoding in one step.
+        """
+        self.use_slicing = False
+
     def clear_cache(self):
         def _count_conv3d(model):
             count = 0
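For reference, the defaults above mean neighbouring tiles overlap by tile_sample_min minus tile_sample_stride, i.e. 256 - 192 = 64 pixels per spatial dimension. A small sketch of the derived latent-space quantities, assuming the spatial compression ratio evaluates to 8 (i.e. `temperal_downsample` has three entries); the frame height is an arbitrary example:

    # Illustrative arithmetic only; mirrors the attribute names introduced above.
    tile_sample_min = 256          # minimal tile size in pixel space
    tile_sample_stride = 192       # distance between tile origins in pixel space
    spatial_compression_ratio = 8  # assumption: 2 ** len(temperal_downsample) with 3 entries

    overlap_px = tile_sample_min - tile_sample_stride                      # 64 pixels blended between tiles
    tile_latent_min = tile_sample_min // spatial_compression_ratio         # 32 latent pixels per tile
    tile_latent_stride = tile_sample_stride // spatial_compression_ratio   # 24 latent pixels per stride
    blend_extent = tile_latent_min - tile_latent_stride                    # 8 latent pixels blended when encoding

    height = 480
    num_tile_rows = len(range(0, height, tile_sample_stride))  # 3 tile rows for a 480-px-high frame
    print(overlap_px, tile_latent_min, tile_latent_stride, blend_extent, num_tile_rows)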
@@ -746,11 +816,14 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self._enc_conv_idx = [0]
         self._enc_feat_map = [None] * self._enc_conv_num
 
-    def _encode(self, x: torch.Tensor):
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        _, _, num_frame, height, width = x.shape
+
+        if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
+            return self.tiled_encode(x)
+
         self.clear_cache()
-        ## cache
-        t = x.shape[2]
-        iter_ = 1 + (t - 1) // 4
+        iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
@@ -764,8 +837,6 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             out = torch.cat([out, out_], 2)
 
         enc = self.quant_conv(out)
-        mu, logvar = enc[:, : self.z_dim, :, :, :], enc[:, self.z_dim :, :, :, :]
-        enc = torch.cat([mu, logvar], dim=1)
         self.clear_cache()
         return enc
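The encoder's temporal loop above handles the first frame on its own and then consumes the remaining frames in chunks of four, matching the causal feature cache. A quick, standalone check of that chunking arithmetic (frame counts are illustrative):

    # Illustrative check of the temporal chunking used in _encode / tiled_encode.
    def encode_chunks(num_frame: int):
        iter_ = 1 + (num_frame - 1) // 4
        chunks = []
        for k in range(iter_):
            if k == 0:
                chunks.append((0, 1))                        # first frame alone
            else:
                chunks.append((1 + 4 * (k - 1), 1 + 4 * k))  # then blocks of 4 frames
        return chunks

    print(encode_chunks(9))   # [(0, 1), (1, 5), (5, 9)] -> 3 cache iterations
    print(encode_chunks(17))  # [(0, 1), (1, 5), (5, 9), (9, 13), (13, 17)]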
@@ -785,18 +856,28 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
...
@@ -785,18 +856,28 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
The latent representations of the encoded videos. If `return_dict` is True, a
The latent representations of the encoded videos. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
"""
h
=
self
.
_encode
(
x
)
if
self
.
use_slicing
and
x
.
shape
[
0
]
>
1
:
encoded_slices
=
[
self
.
_encode
(
x_slice
)
for
x_slice
in
x
.
split
(
1
)]
h
=
torch
.
cat
(
encoded_slices
)
else
:
h
=
self
.
_encode
(
x
)
posterior
=
DiagonalGaussianDistribution
(
h
)
posterior
=
DiagonalGaussianDistribution
(
h
)
if
not
return_dict
:
if
not
return_dict
:
return
(
posterior
,)
return
(
posterior
,)
return
AutoencoderKLOutput
(
latent_dist
=
posterior
)
return
AutoencoderKLOutput
(
latent_dist
=
posterior
)
def
_decode
(
self
,
z
:
torch
.
Tensor
,
return_dict
:
bool
=
True
)
->
Union
[
DecoderOutput
,
torch
.
Tensor
]:
def
_decode
(
self
,
z
:
torch
.
Tensor
,
return_dict
:
bool
=
True
):
self
.
clear_cache
()
_
,
_
,
num_frame
,
height
,
width
=
z
.
shape
tile_latent_min_height
=
self
.
tile_sample_min_height
//
self
.
spatial_compression_ratio
tile_latent_min_width
=
self
.
tile_sample_min_width
//
self
.
spatial_compression_ratio
if
self
.
use_tiling
and
(
width
>
tile_latent_min_width
or
height
>
tile_latent_min_height
):
return
self
.
tiled_decode
(
z
,
return_dict
=
return_dict
)
iter_
=
z
.
shape
[
2
]
self
.
clear_cache
()
x
=
self
.
post_quant_conv
(
z
)
x
=
self
.
post_quant_conv
(
z
)
for
i
in
range
(
iter_
):
for
i
in
range
(
num_frame
):
self
.
_conv_idx
=
[
0
]
self
.
_conv_idx
=
[
0
]
if
i
==
0
:
if
i
==
0
:
out
=
self
.
decoder
(
x
[:,
:,
i
:
i
+
1
,
:,
:],
feat_cache
=
self
.
_feat_map
,
feat_idx
=
self
.
_conv_idx
)
out
=
self
.
decoder
(
x
[:,
:,
i
:
i
+
1
,
:,
:],
feat_cache
=
self
.
_feat_map
,
feat_idx
=
self
.
_conv_idx
)
...
@@ -826,12 +907,161 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
...
@@ -826,12 +907,161 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
returned.
"""
"""
decoded
=
self
.
_decode
(
z
).
sample
if
self
.
use_slicing
and
z
.
shape
[
0
]
>
1
:
decoded_slices
=
[
self
.
_decode
(
z_slice
).
sample
for
z_slice
in
z
.
split
(
1
)]
decoded
=
torch
.
cat
(
decoded_slices
)
else
:
decoded
=
self
.
_decode
(
z
).
sample
if
not
return_dict
:
if
not
return_dict
:
return
(
decoded
,)
return
(
decoded
,)
return
DecoderOutput
(
sample
=
decoded
)
return
DecoderOutput
(
sample
=
decoded
)
def
blend_v
(
self
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
blend_extent
:
int
)
->
torch
.
Tensor
:
blend_extent
=
min
(
a
.
shape
[
-
2
],
b
.
shape
[
-
2
],
blend_extent
)
for
y
in
range
(
blend_extent
):
b
[:,
:,
:,
y
,
:]
=
a
[:,
:,
:,
-
blend_extent
+
y
,
:]
*
(
1
-
y
/
blend_extent
)
+
b
[:,
:,
:,
y
,
:]
*
(
y
/
blend_extent
)
return
b
def
blend_h
(
self
,
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
blend_extent
:
int
)
->
torch
.
Tensor
:
blend_extent
=
min
(
a
.
shape
[
-
1
],
b
.
shape
[
-
1
],
blend_extent
)
for
x
in
range
(
blend_extent
):
b
[:,
:,
:,
:,
x
]
=
a
[:,
:,
:,
:,
-
blend_extent
+
x
]
*
(
1
-
x
/
blend_extent
)
+
b
[:,
:,
:,
:,
x
]
*
(
x
/
blend_extent
)
return
b
def
tiled_encode
(
self
,
x
:
torch
.
Tensor
)
->
AutoencoderKLOutput
:
r
"""Encode a batch of images using a tiled encoder.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`:
The latent representation of the encoded videos.
"""
_
,
_
,
num_frames
,
height
,
width
=
x
.
shape
latent_height
=
height
//
self
.
spatial_compression_ratio
latent_width
=
width
//
self
.
spatial_compression_ratio
tile_latent_min_height
=
self
.
tile_sample_min_height
//
self
.
spatial_compression_ratio
tile_latent_min_width
=
self
.
tile_sample_min_width
//
self
.
spatial_compression_ratio
tile_latent_stride_height
=
self
.
tile_sample_stride_height
//
self
.
spatial_compression_ratio
tile_latent_stride_width
=
self
.
tile_sample_stride_width
//
self
.
spatial_compression_ratio
blend_height
=
tile_latent_min_height
-
tile_latent_stride_height
blend_width
=
tile_latent_min_width
-
tile_latent_stride_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows
=
[]
for
i
in
range
(
0
,
height
,
self
.
tile_sample_stride_height
):
row
=
[]
for
j
in
range
(
0
,
width
,
self
.
tile_sample_stride_width
):
self
.
clear_cache
()
time
=
[]
frame_range
=
1
+
(
num_frames
-
1
)
//
4
for
k
in
range
(
frame_range
):
self
.
_enc_conv_idx
=
[
0
]
if
k
==
0
:
tile
=
x
[:,
:,
:
1
,
i
:
i
+
self
.
tile_sample_min_height
,
j
:
j
+
self
.
tile_sample_min_width
]
else
:
tile
=
x
[
:,
:,
1
+
4
*
(
k
-
1
)
:
1
+
4
*
k
,
i
:
i
+
self
.
tile_sample_min_height
,
j
:
j
+
self
.
tile_sample_min_width
,
]
tile
=
self
.
encoder
(
tile
,
feat_cache
=
self
.
_enc_feat_map
,
feat_idx
=
self
.
_enc_conv_idx
)
tile
=
self
.
quant_conv
(
tile
)
time
.
append
(
tile
)
row
.
append
(
torch
.
cat
(
time
,
dim
=
2
))
rows
.
append
(
row
)
self
.
clear_cache
()
result_rows
=
[]
for
i
,
row
in
enumerate
(
rows
):
result_row
=
[]
for
j
,
tile
in
enumerate
(
row
):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if
i
>
0
:
tile
=
self
.
blend_v
(
rows
[
i
-
1
][
j
],
tile
,
blend_height
)
if
j
>
0
:
tile
=
self
.
blend_h
(
row
[
j
-
1
],
tile
,
blend_width
)
result_row
.
append
(
tile
[:,
:,
:,
:
tile_latent_stride_height
,
:
tile_latent_stride_width
])
result_rows
.
append
(
torch
.
cat
(
result_row
,
dim
=-
1
))
enc
=
torch
.
cat
(
result_rows
,
dim
=
3
)[:,
:,
:,
:
latent_height
,
:
latent_width
]
return
enc
def
tiled_decode
(
self
,
z
:
torch
.
Tensor
,
return_dict
:
bool
=
True
)
->
Union
[
DecoderOutput
,
torch
.
Tensor
]:
r
"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
_
,
_
,
num_frames
,
height
,
width
=
z
.
shape
sample_height
=
height
*
self
.
spatial_compression_ratio
sample_width
=
width
*
self
.
spatial_compression_ratio
tile_latent_min_height
=
self
.
tile_sample_min_height
//
self
.
spatial_compression_ratio
tile_latent_min_width
=
self
.
tile_sample_min_width
//
self
.
spatial_compression_ratio
tile_latent_stride_height
=
self
.
tile_sample_stride_height
//
self
.
spatial_compression_ratio
tile_latent_stride_width
=
self
.
tile_sample_stride_width
//
self
.
spatial_compression_ratio
blend_height
=
self
.
tile_sample_min_height
-
self
.
tile_sample_stride_height
blend_width
=
self
.
tile_sample_min_width
-
self
.
tile_sample_stride_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows
=
[]
for
i
in
range
(
0
,
height
,
tile_latent_stride_height
):
row
=
[]
for
j
in
range
(
0
,
width
,
tile_latent_stride_width
):
self
.
clear_cache
()
time
=
[]
for
k
in
range
(
num_frames
):
self
.
_conv_idx
=
[
0
]
tile
=
z
[:,
:,
k
:
k
+
1
,
i
:
i
+
tile_latent_min_height
,
j
:
j
+
tile_latent_min_width
]
tile
=
self
.
post_quant_conv
(
tile
)
decoded
=
self
.
decoder
(
tile
,
feat_cache
=
self
.
_feat_map
,
feat_idx
=
self
.
_conv_idx
)
time
.
append
(
decoded
)
row
.
append
(
torch
.
cat
(
time
,
dim
=
2
))
rows
.
append
(
row
)
self
.
clear_cache
()
result_rows
=
[]
for
i
,
row
in
enumerate
(
rows
):
result_row
=
[]
for
j
,
tile
in
enumerate
(
row
):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if
i
>
0
:
tile
=
self
.
blend_v
(
rows
[
i
-
1
][
j
],
tile
,
blend_height
)
if
j
>
0
:
tile
=
self
.
blend_h
(
row
[
j
-
1
],
tile
,
blend_width
)
result_row
.
append
(
tile
[:,
:,
:,
:
self
.
tile_sample_stride_height
,
:
self
.
tile_sample_stride_width
])
result_rows
.
append
(
torch
.
cat
(
result_row
,
dim
=-
1
))
dec
=
torch
.
cat
(
result_rows
,
dim
=
3
)[:,
:,
:,
:
sample_height
,
:
sample_width
]
if
not
return_dict
:
return
(
dec
,)
return
DecoderOutput
(
sample
=
dec
)
def
forward
(
def
forward
(
self
,
self
,
sample
:
torch
.
Tensor
,
sample
:
torch
.
Tensor
,
...
...
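The blend_v / blend_h helpers above fade the overlapping strip of each tile linearly from the previous tile into the current one. A toy, standalone demonstration of the same ramp on small tensors (not tied to the model):

    import torch

    def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        # Same ramp as the method above, on 5-D (batch, channel, frame, height, width) tensors.
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
                x / blend_extent
            )
        return b

    left = torch.zeros(1, 1, 1, 1, 8)   # tile decoded on the left, value 0
    right = torch.ones(1, 1, 1, 1, 8)   # tile decoded on the right, value 1
    blended = blend_h(left, right, blend_extent=4)
    print(blended[0, 0, 0, 0])  # 0.0, 0.25, 0.5, 0.75 across the overlap, then 1.0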
tests/models/autoencoders/test_models_autoencoder_wan.py (view file @ 8520d496)
@@ -15,6 +15,8 @@
 import unittest
 
+import torch
+
 from diffusers import AutoencoderKLWan
 from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device
@@ -44,9 +46,16 @@ class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase
         num_frames = 9
         num_channels = 3
         sizes = (16, 16)
 
         image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
 
         return {"sample": image}
 
+    @property
+    def dummy_input_tiling(self):
+        batch_size = 2
+        num_frames = 9
+        num_channels = 3
+        sizes = (128, 128)
+        image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
+        return {"sample": image}
+
     @property
@@ -62,6 +71,73 @@ class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    def prepare_init_args_and_inputs_for_tiling(self):
+        init_dict = self.get_autoencoder_kl_wan_config()
+        inputs_dict = self.dummy_input_tiling
+        return init_dict, inputs_dict
+
+    def test_enable_disable_tiling(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling()
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict).to(torch_device)
+
+        inputs_dict.update({"return_dict": False})
+
+        torch.manual_seed(0)
+        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        torch.manual_seed(0)
+        model.enable_tiling(96, 96, 64, 64)
+        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertLess(
+            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
+            0.5,
+            "VAE tiling should not affect the inference results",
+        )
+
+        torch.manual_seed(0)
+        model.disable_tiling()
+        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertEqual(
+            output_without_tiling.detach().cpu().numpy().all(),
+            output_without_tiling_2.detach().cpu().numpy().all(),
+            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
+        )
+
+    def test_enable_disable_slicing(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict).to(torch_device)
+
+        inputs_dict.update({"return_dict": False})
+
+        torch.manual_seed(0)
+        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        torch.manual_seed(0)
+        model.enable_slicing()
+        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertLess(
+            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
+            0.05,
+            "VAE slicing should not affect the inference results",
+        )
+
+        torch.manual_seed(0)
+        model.disable_slicing()
+        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
+
+        self.assertEqual(
+            output_without_slicing.detach().cpu().numpy().all(),
+            output_without_slicing_2.detach().cpu().numpy().all(),
+            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
+        )
+
     @unittest.skip("Gradient checkpointing has not been implemented yet")
     def test_gradient_checkpointing_is_applied(self):
         pass
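To exercise just the new coverage locally, a standard pytest selection along these lines should work (the -k expression is only a convenience filter):

    pytest tests/models/autoencoders/test_models_autoencoder_wan.py -k "tiling or slicing" -q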