renzhc / diffusers_dcu · Commits

Commit cfd6ec74 (unverified) · authored Aug 06, 2025 by Aryan · committed by GitHub, Aug 06, 2025

[refactor] condense group offloading (#11990)

* update
* update
* refactor
* add test
* address review comment
* nit
Parent: 1082c46a

Showing 2 changed files with 161 additions and 102 deletions:

  src/diffusers/hooks/group_offloading.py   (+74, -102)
  tests/hooks/test_group_offloading.py      (+87, -0)
src/diffusers/hooks/group_offloading.py
@@ -95,7 +95,7 @@ class ModuleGroup:
         self.offload_to_disk_path = offload_to_disk_path
         self._is_offloaded_to_disk = False

-        if self.offload_to_disk_path:
+        if self.offload_to_disk_path is not None:
             # Instead of `group_id or str(id(self))` we do this because `group_id` can be "" as well.
             self.group_id = group_id if group_id is not None else str(id(self))
             short_hash = _compute_group_hash(self.group_id)
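The switch to an explicit `is not None` check matters because, as the comment above notes, values such as `group_id` (and a disk path) can legitimately be empty strings, which are falsy. A minimal sketch of the distinction, with hypothetical values not taken from the diff:

    offload_to_disk_path = ""            # falsy, but deliberately set
    bool(offload_to_disk_path)           # False -> a truthiness check would skip this branch
    offload_to_disk_path is not None     # True  -> the explicit check keeps it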
@@ -115,6 +115,12 @@ class ModuleGroup:
         else:
             self.cpu_param_dict = self._init_cpu_param_dict()

+        self._torch_accelerator_module = (
+            getattr(torch, torch.accelerator.current_accelerator().type)
+            if hasattr(torch, "accelerator")
+            else torch.cuda
+        )
+
     def _init_cpu_param_dict(self):
         cpu_param_dict = {}
         if self.stream is None:
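The cached `_torch_accelerator_module` resolves the device-specific torch namespace once per group instead of recomputing it in every method. A minimal sketch of the same pattern, assuming a PyTorch build that exposes `torch.accelerator` and that an accelerator is actually present:

    import torch

    accelerator_module = (
        getattr(torch, torch.accelerator.current_accelerator().type)  # e.g. torch.cuda or torch.xpu
        if hasattr(torch, "accelerator")
        else torch.cuda
    )
    stream = accelerator_module.Stream()           # device-agnostic stream creation
    current = accelerator_module.current_stream()  # later used for record_stream()/synchronize()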
@@ -138,112 +144,76 @@ class ModuleGroup:
     @contextmanager
     def _pinned_memory_tensors(self):
         pinned_dict = {}
         try:
-            for param, tensor in self.cpu_param_dict.items():
-                if not tensor.is_pinned():
-                    pinned_dict[param] = tensor.pin_memory()
-                else:
-                    pinned_dict[param] = tensor
+            pinned_dict = {
+                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                for param, tensor in self.cpu_param_dict.items()
+            }
             yield pinned_dict
         finally:
             pinned_dict = None

-    def _transfer_tensor_to_device(self, tensor, source_tensor, current_stream=None):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream and current_stream is not None:
-            tensor.data.record_stream(current_stream)
+        if self.record_stream:
+            tensor.data.record_stream(self._torch_accelerator_module.current_stream())

-    def _process_tensors_from_modules(self, pinned_memory=None, current_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, current_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, current_stream)
+                self._transfer_tensor_to_device(buffer, source)

         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, current_stream)
+            self._transfer_tensor_to_device(param, source)

         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, current_stream)
+            self._transfer_tensor_to_device(buffer, source)

-    def _onload_from_disk(self, current_stream):
-        if self.stream is not None:
-            loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-
-            for key, tensor_obj in self.key_to_tensor.items():
-                self.cpu_param_dict[tensor_obj] = loaded_cpu_tensors[key]
-
-            with self._pinned_memory_tensors() as pinned_memory:
-                for key, tensor_obj in self.key_to_tensor.items():
-                    self._transfer_tensor_to_device(tensor_obj, pinned_memory[tensor_obj], current_stream)
-
-            self.cpu_param_dict.clear()
-        else:
-            onload_device = (
-                self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-            )
-            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-            for key, tensor_obj in self.key_to_tensor.items():
-                tensor_obj.data = loaded_tensors[key]
-
-    def _onload_from_memory(self, current_stream):
-        if self.stream is not None:
-            with self._pinned_memory_tensors() as pinned_memory:
-                self._process_tensors_from_modules(pinned_memory, current_stream)
-        else:
-            self._process_tensors_from_modules(None, current_stream)
-
-    @torch.compiler.disable()
-    def onload_(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
-        context = nullcontext() if self.stream is None else torch_accelerator_module.stream(self.stream)
-        current_stream = torch_accelerator_module.current_stream() if self.record_stream else None
-
-        if self.offload_to_disk_path:
-            if self.stream is not None:
-                # Wait for previous Host->Device transfer to complete
-                self.stream.synchronize()
-
-            with context:
-                if self.stream is not None:
-                    # Load to CPU, pin, and async copy to device for overlapping transfer and compute
-                    loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        pinned_tensor = loaded_cpu_tensors[key].pin_memory()
-                        tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                        if self.record_stream:
-                            tensor_obj.data.record_stream(current_stream)
-                else:
-                    # Load directly to the target device (synchronous)
-                    onload_device = (
-                        self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
-                    )
-                    loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
-                    for key, tensor_obj in self.key_to_tensor.items():
-                        tensor_obj.data = loaded_tensors[key]
-            return
-
-        if self.stream is not None:
-            # Wait for previous Host->Device transfer to complete
-            self.stream.synchronize()
-
-        with context:
-            if self.offload_to_disk_path:
-                self._onload_from_disk(current_stream)
-            else:
-                self._onload_from_memory(current_stream)
+    def _onload_from_disk(self):
+        if self.stream is not None:
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
+
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
+
+        with context:
+            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
+            device = str(self.onload_device) if self.stream is None else "cpu"
+            loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
+            if self.stream is not None:
+                for key, tensor_obj in self.key_to_tensor.items():
+                    pinned_tensor = loaded_tensors[key].pin_memory()
+                    tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
+                    if self.record_stream:
+                        tensor_obj.data.record_stream(current_stream)
+            else:
+                onload_device = (
+                    self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
+                )
+                loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
+                for key, tensor_obj in self.key_to_tensor.items():
+                    tensor_obj.data = loaded_tensors[key]
+
+    def _onload_from_memory(self):
+        if self.stream is not None:
+            # Wait for previous Host->Device transfer to complete
+            self.stream.synchronize()
+
+        context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
+        with context:
+            if self.stream is not None:
+                with self._pinned_memory_tensors() as pinned_memory:
+                    self._process_tensors_from_modules(pinned_memory)
+            else:
+                self._process_tensors_from_modules(None)

     def _offload_to_disk(self):
         # TODO: we can potentially optimize this code path by checking if the _all_ the desired
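The condensed helpers keep the same transfer pattern as before: pin host memory, copy with non_blocking=True on a transfer stream, and either record_stream() the result or synchronize before the buffers are reused. A minimal standalone sketch of that pattern, assuming a CUDA device (this is illustrative, not code from the commit):

    import torch

    if torch.cuda.is_available():
        src = torch.randn(1024).pin_memory()          # page-locked host tensor enables async H2D copies
        transfer_stream = torch.cuda.Stream()
        compute_stream = torch.cuda.current_stream()
        with torch.cuda.stream(transfer_stream):
            dst = src.to("cuda", non_blocking=True)   # overlaps with work on the compute stream
            dst.record_stream(compute_stream)         # tell the allocator dst is consumed on another stream
        compute_stream.wait_stream(transfer_stream)   # or transfer_stream.synchronize() before using dst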
@@ -264,14 +234,10 @@ class ModuleGroup:
                 tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)

     def _offload_to_memory(self):
-        torch_accelerator_module = (
-            getattr(torch, torch.accelerator.current_accelerator().type)
-            if hasattr(torch, "accelerator")
-            else torch.cuda
-        )
         if self.stream is not None:
             if not self.record_stream:
-                torch_accelerator_module.current_stream().synchronize()
+                self._torch_accelerator_module.current_stream().synchronize()
             for group_module in self.modules:
                 for param in group_module.parameters():
                     param.data = self.cpu_param_dict[param]
@@ -282,15 +248,23 @@ class ModuleGroup:
         else:
             for group_module in self.modules:
-                group_module.to(self.offload_device, non_blocking=self.non_blocking)
+                group_module.to(self.offload_device, non_blocking=False)
             for param in self.parameters:
-                param.data = param.data.to(self.offload_device, non_blocking=self.non_blocking)
+                param.data = param.data.to(self.offload_device, non_blocking=False)
             for buffer in self.buffers:
-                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
+                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
+
+    @torch.compiler.disable()
+    def onload_(self):
+        r"""Onloads the group of parameters to the onload_device."""
+        if self.offload_to_disk_path is not None:
+            self._onload_from_disk()
+        else:
+            self._onload_from_memory()

     @torch.compiler.disable()
     def offload_(self):
-        r"""Offloads the group of modules to the offload_device."""
+        r"""Offloads the group of parameters to the offload_device."""
         if self.offload_to_disk_path:
             self._offload_to_disk()
         else:
@@ -307,11 +281,9 @@ class GroupOffloadingHook(ModelHook):
     _is_stateful = False

-    def __init__(
-        self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
-    ) -> None:
+    def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
         self.group = group
-        self.next_group = next_group
+        self.next_group: Optional[ModuleGroup] = None
         self.config = config

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
@@ -459,8 +431,8 @@ class LayerExecutionTrackerHook(ModelHook):
 def apply_group_offloading(
     module: torch.nn.Module,
-    onload_device: torch.device,
-    offload_device: torch.device = torch.device("cpu"),
+    onload_device: Union[str, torch.device],
+    offload_device: Union[str, torch.device] = torch.device("cpu"),
     offload_type: Union[str, GroupOffloadingType] = "block_level",
     num_blocks_per_group: Optional[int] = None,
     non_blocking: bool = False,
@@ -546,6 +518,8 @@ def apply_group_offloading(
     ```
     """
+    onload_device = torch.device(onload_device) if isinstance(onload_device, str) else onload_device
+    offload_device = torch.device(offload_device) if isinstance(offload_device, str) else offload_device
     offload_type = GroupOffloadingType(offload_type)

     stream = None
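With the Union[str, torch.device] annotations and the normalization above, callers can pass plain device strings. A hedged usage sketch with a toy module (TinyModel is hypothetical and not part of the library):

    import torch
    from diffusers.hooks import apply_group_offloading

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = torch.nn.ModuleList([torch.nn.Linear(8, 8) for _ in range(4)])

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    model = TinyModel()
    # Strings are accepted and converted with torch.device(...) internally.
    apply_group_offloading(
        model, onload_device="cuda", offload_device="cpu", offload_type="block_level", num_blocks_per_group=1
    )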
@@ -633,7 +607,7 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig)
     # Apply group offloading hooks to the module groups
     for i, group in enumerate(matched_module_groups):
         for group_module in group.modules:
-            _apply_group_offloading_hook(group_module, group, None, config=config)
+            _apply_group_offloading_hook(group_module, group, config=config)

     # Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
     # when the forward pass of this module is called. This is because the top-level module is not
@@ -662,9 +636,9 @@ def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOf
group_id
=
f
"
{
module
.
__class__
.
__name__
}
_unmatched_group"
,
)
if
config
.
stream
is
None
:
_apply_group_offloading_hook
(
module
,
unmatched_group
,
None
,
config
=
config
)
_apply_group_offloading_hook
(
module
,
unmatched_group
,
config
=
config
)
else
:
_apply_lazy_group_offloading_hook
(
module
,
unmatched_group
,
None
,
config
=
config
)
_apply_lazy_group_offloading_hook
(
module
,
unmatched_group
,
config
=
config
)
def
_apply_group_offloading_leaf_level
(
module
:
torch
.
nn
.
Module
,
config
:
GroupOffloadingConfig
)
->
None
:
...
...
@@ -693,7 +667,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig)
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(submodule, group, None, config=config)
+        _apply_group_offloading_hook(submodule, group, config=config)
         modules_with_group_offloading.add(name)

     # Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
@@ -740,7 +714,7 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig)
             onload_self=True,
             group_id=name,
         )
-        _apply_group_offloading_hook(parent_module, group, None, config=config)
+        _apply_group_offloading_hook(parent_module, group, config=config)

     if config.stream is not None:
         # When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
@@ -762,13 +736,12 @@ def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig)
         onload_self=True,
         group_id=_GROUP_ID_LAZY_LEAF,
     )
-    _apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
+    _apply_lazy_group_offloading_hook(module, unmatched_group, config=config)


 def _apply_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -777,14 +750,13 @@ def _apply_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)


 def _apply_lazy_group_offloading_hook(
     module: torch.nn.Module,
     group: ModuleGroup,
-    next_group: Optional[ModuleGroup] = None,
     *,
     config: GroupOffloadingConfig,
 ) -> None:
@@ -793,7 +765,7 @@ def _apply_lazy_group_offloading_hook(
     # We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
     # is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
     if registry.get_hook(_GROUP_OFFLOADING) is None:
-        hook = GroupOffloadingHook(group, next_group, config=config)
+        hook = GroupOffloadingHook(group, config=config)
         registry.register_hook(hook, _GROUP_OFFLOADING)

     lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
tests/hooks/test_group_offloading.py
@@ -17,7 +17,9 @@ import gc
 import unittest

 import torch
+from parameterized import parameterized

+from diffusers.hooks import HookRegistry, ModelHook
 from diffusers.models import ModelMixin
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.utils import get_logger
@@ -99,6 +101,29 @@ class DummyModelWithMultipleBlocks(ModelMixin):
         return x


+# Test for https://github.com/huggingface/diffusers/pull/12077
+class DummyModelWithLayerNorm(ModelMixin):
+    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
+        super().__init__()
+        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
+        self.activation = torch.nn.ReLU()
+        self.blocks = torch.nn.ModuleList(
+            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
+        )
+        self.layer_norm = torch.nn.LayerNorm(hidden_features, elementwise_affine=True)
+        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.linear_1(x)
+        x = self.activation(x)
+        for block in self.blocks:
+            x = block(x)
+        x = self.layer_norm(x)
+        x = self.linear_2(x)
+        return x
+
+
 class DummyPipeline(DiffusionPipeline):
     model_cpu_offload_seq = "model"
@@ -113,6 +138,16 @@ class DummyPipeline(DiffusionPipeline):
         return x


+class LayerOutputTrackerHook(ModelHook):
+    def __init__(self):
+        super().__init__()
+        self.outputs = []
+
+    def post_forward(self, module, output):
+        self.outputs.append(output)
+        return output
+
+
 @require_torch_accelerator
 class GroupOffloadTests(unittest.TestCase):
     in_features = 64
@@ -258,6 +293,7 @@ class GroupOffloadTests(unittest.TestCase):
def
test_block_level_stream_with_invocation_order_different_from_initialization_order
(
self
):
if
torch
.
device
(
torch_device
).
type
not
in
[
"cuda"
,
"xpu"
]:
return
model
=
DummyModelWithMultipleBlocks
(
in_features
=
self
.
in_features
,
hidden_features
=
self
.
hidden_features
,
...
...
@@ -274,3 +310,54 @@ class GroupOffloadTests(unittest.TestCase):
         with context:
             model(self.input)
+
+    @parameterized.expand([("block_level",), ("leaf_level",)])
+    def test_block_level_offloading_with_parameter_only_module_group(self, offload_type: str):
+        if torch.device(torch_device).type not in ["cuda", "xpu"]:
+            return
+
+        def apply_layer_output_tracker_hook(model: DummyModelWithLayerNorm):
+            for name, module in model.named_modules():
+                registry = HookRegistry.check_if_exists_or_initialize(module)
+                hook = LayerOutputTrackerHook()
+                registry.register_hook(hook, "layer_output_tracker")
+
+        model_ref = DummyModelWithLayerNorm(128, 256, 128, 2)
+        model = DummyModelWithLayerNorm(128, 256, 128, 2)
+        model.load_state_dict(model_ref.state_dict(), strict=True)
+
+        model_ref.to(torch_device)
+        model.enable_group_offload(torch_device, offload_type=offload_type, num_blocks_per_group=1, use_stream=True)
+
+        apply_layer_output_tracker_hook(model_ref)
+        apply_layer_output_tracker_hook(model)
+
+        x = torch.randn(2, 128).to(torch_device)
+
+        out_ref = model_ref(x)
+        out = model(x)
+        self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match.")
+
+        num_repeats = 4
+        for i in range(num_repeats):
+            out_ref = model_ref(x)
+            out = model(x)
+            self.assertTrue(torch.allclose(out_ref, out, atol=1e-5), "Outputs do not match after multiple invocations.")
+
+        for (ref_name, ref_module), (name, module) in zip(model_ref.named_modules(), model.named_modules()):
+            assert ref_name == name
+            ref_outputs = (
+                HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
+            )
+            outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
+            cumulated_absmax = 0.0
+            for i in range(len(outputs)):
+                diff = ref_outputs[0] - outputs[i]
+                absdiff = diff.abs()
+                absmax = absdiff.max().item()
+                cumulated_absmax += absmax
+            self.assertLess(
+                cumulated_absmax, 1e-5, f"Output differences for {name} exceeded threshold: {cumulated_absmax:.5f}"
+            )
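If the test layout matches upstream diffusers, the new case can presumably be run in isolation with pytest's keyword filter, e.g. `pytest tests/hooks/test_group_offloading.py -k parameter_only_module_group`.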