Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
cd5cf2bc
Unverified
Commit
cd5cf2bc
authored
Sep 15, 2022
by
Super Daniel
Committed by
GitHub
Sep 15, 2022
Browse files
[fx/tuning] tune performance on rotor with meta info. (#1599)
parent
a7cda6f5
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
96 additions
and
107 deletions
+96
-107
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
+14
-69
colossalai/fx/passes/algorithms/linearize.py
colossalai/fx/passes/algorithms/linearize.py
+5
-1
colossalai/fx/profiler/__init__.py
colossalai/fx/profiler/__init__.py
+1
-1
colossalai/fx/profiler/dataflow.py
colossalai/fx/profiler/dataflow.py
+16
-12
colossalai/fx/profiler/memory.py
colossalai/fx/profiler/memory.py
+34
-2
colossalai/fx/profiler/opcount.py
colossalai/fx/profiler/opcount.py
+1
-0
colossalai/fx/profiler/profiler.py
colossalai/fx/profiler/profiler.py
+25
-22
No files found.
colossalai/fx/passes/algorithms/ckpt_solver_rotor.py
View file @
cd5cf2bc
from
typing
import
List
,
Tuple
import
torch
from
torch.fx
import
GraphModule
,
Node
from
torch.fx
import
Node
from
colossalai.fx.graph_module
import
ColoGraphModule
from
colossalai.fx.profiler
import
parameter_size
from
colossalai.fx.profiler
import
activation_size
,
parameter_size
import
math
from
.linearize
import
linearize
from
.utils
import
*
...
...
@@ -31,7 +30,7 @@ def _compute_table(chain: Chain, mmax) -> Tuple:
# Build table
opt
=
[[{}
for
_
in
range
(
chain
.
length
+
1
)]
for
_
in
range
(
mmax
+
1
)]
what
=
[[{}
for
_
in
range
(
chain
.
length
+
1
)]
for
_
in
range
(
mmax
+
1
)]
#
#
Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Last one is a dict because its indices go from i to l. Renumbering will wait for C implementation
# Initialize borders of the tables for lmax-lmin = 0
for
m
in
range
(
mmax
+
1
):
...
...
@@ -115,43 +114,6 @@ def _discretize(mem_unit, values):
return
[
math
.
ceil
(
value
/
mem_unit
)
for
value
in
values
]
def
_compute_size
(
obj
:
torch
.
Tensor
)
->
int
:
return
obj
.
numel
()
*
obj
.
element_size
()
def
_compute_output_size
(
node
:
List
[
Node
])
->
int
:
"""Compute the output size of a node
Args:
node (List[Node]): node, list of torch.fx.Node
Returns:
int: output size
"""
return
node
[
-
1
].
meta
[
'tensor_meta'
].
numel
*
torch
.
tensor
([],
dtype
=
node
[
-
1
].
meta
[
'tensor_meta'
].
dtype
).
element_size
()
def
_get_inplace
(
node
:
Node
)
->
bool
:
"""Get the inplace argument from torch.fx.Node
Args:
node (Node): torch.fx.Node
Returns:
bool: indicates whether this op is inplace
"""
is_inplace
=
False
if
node
.
op
==
"call_function"
:
is_inplace
=
node
.
kwargs
.
get
(
"inplace"
,
False
)
elif
node
.
op
==
"call_module"
:
is_inplace
=
getattr
(
node
.
graph
.
owning_module
.
get_submodule
(
node
.
target
),
"inplace"
,
False
)
return
is_inplace
def
_fwd_xbar
(
node
:
List
[
Node
])
->
int
:
"""Get the forward xbar of a node
...
...
@@ -221,46 +183,33 @@ def _get_bwd_mem_tmp(node: List[Node]) -> int:
for
k
,
v
in
deps
.
items
():
if
v
>
0
:
deps_size
+=
k
.
meta
[
'bwd_mem_out'
]
if
v
==
float
(
'-inf'
):
deps_size
-=
k
.
meta
[
'fwd_mem_tmp'
]
+
k
.
meta
[
'fwd_mem_out'
]
return
deps_size
bwd_mem_tmp
=
0
deps
=
{}
# add all the users for last node into deps,
# as those nodes' gradient out will be stored in memory
for
child
in
node
[
-
1
].
users
:
deps
[
child
]
=
1
for
n
in
reversed
(
node
):
deps
[
n
]
=
len
(
n
.
all_input_nodes
)
bwd_mem_tmp
=
max
(
bwd_mem_tmp
,
_get_deps_size
()
+
n
.
meta
[
'bwd_mem_tmp'
])
deps
[
n
]
=
len
(
n
.
all_input_nodes
)
for
child
in
n
.
users
:
if
child
in
deps
:
deps
[
child
]
-=
1
for
key
in
list
(
deps
.
keys
()):
if
deps
[
key
]
==
0
:
del
deps
[
key
]
if
deps
[
child
]
<=
0
:
deps
[
child
]
=
float
(
'-inf'
)
# free
return
bwd_mem_tmp
def
_construct_chain
(
node_list
:
List
[
List
[
Node
]],
data
,
mem_unit
:
int
)
->
Chain
:
def
_construct_chain
(
node_list
:
List
[
List
[
Node
]],
input
,
mem_unit
:
int
)
->
Chain
:
fwd_time
=
[]
bwd_time
=
[]
if
isinstance
(
data
,
torch
.
Tensor
):
xbar_sizes
=
[
_compute_size
(
data
)]
x_sizes
=
[
_compute_size
(
data
)]
elif
isinstance
(
data
,
list
)
or
isinstance
(
data
,
tuple
):
xbar_sizes
=
[
sum
([
_compute_size
(
obj
)
for
obj
in
data
])]
x_sizes
=
[
sum
([
_compute_size
(
obj
)
for
obj
in
data
])]
elif
isinstance
(
data
,
dict
):
xbar_sizes
=
[
sum
([
_compute_size
(
obj
)
for
obj
in
data
.
values
()])]
x_sizes
=
[
sum
([
_compute_size
(
obj
)
for
obj
in
data
.
values
()])]
xbar_sizes
=
[
activation_size
(
input
)]
x_sizes
=
[
activation_size
(
input
)]
# currently we can't get the temp memory needed in fwd
tmp_fwd
=
[
0
]
*
len
(
node_list
)
tmp_bwd
=
[]
...
...
@@ -268,14 +217,10 @@ def _construct_chain(node_list: List[List[Node]], data, mem_unit: int) -> Chain:
for
idx
,
node
in
enumerate
(
node_list
):
fwd_time
.
append
(
_fwd_time
(
node
))
bwd_time
.
append
(
_bwd_time
(
node
))
x_sizes
.
append
(
_compute_output_size
(
node
)
)
x_sizes
.
append
(
node
[
-
1
].
meta
[
'fwd_mem_out'
]
)
xbar_sizes
.
append
(
max
(
x_sizes
[
-
1
],
_fwd_xbar
(
node
)))
tmp_bwd
.
append
(
_get_bwd_mem_tmp
(
node
))
# if a node with only one inplace op, we need to let x_bar = 0
if
len
(
node
)
==
1
and
_get_inplace
(
node
[
0
]):
xbar_sizes
[
-
1
]
=
0
bwd_time
.
append
(
0
)
# currently we view loss backward temp as zero
...
...
@@ -381,7 +326,7 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit
:
int
,
mem_slots
:
int
=
500
,
cnode
:
List
[
str
]
=
None
,
eps
:
float
=
0.0
2
)
->
ColoGraphModule
:
eps
:
float
=
0.0
)
->
ColoGraphModule
:
"""solver that automatically find activation checkpoint in rotor's manner
Args:
...
...
@@ -390,7 +335,7 @@ def solver_rotor(gm: ColoGraphModule,
mem_limit (int): memory budget in Byte.
mem_slots (int, optional): number of slots for discretizing memory budget. Defaults to 500.
cnode (List[Node], optional): common node list for linearize. Defaults to None.
eps (float): epsilon for memory decay. Defaults to 0.0
2
eps (float): epsilon for memory decay. Defaults to 0.0
Returns:
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
...
...
colossalai/fx/passes/algorithms/linearize.py
View file @
cd5cf2bc
from
typing
import
List
,
Any
from
torch.fx
import
GraphModule
,
Node
from
colossalai.fx.profiler
import
is_inplace
# Common nodes are type of nodes that could be seen as attributes and remain
# unchanged throughout the whole model, it will be used several times by
...
...
@@ -41,6 +42,9 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
Returns:
List[List[Node]]: List of list, each inside list of Node presents
the actual 'node' in linearized manner.
Remarks:
We merge the inplace ops into the previous node.
"""
def
_is_sink
()
->
bool
:
...
...
@@ -50,7 +54,7 @@ def linearize(gm: GraphModule, cnode: List[str] = None) -> List[List[Node]]:
bool
"""
return
not
sum
([
v
for
_
,
v
in
deps
.
items
()])
return
not
sum
([
v
for
_
,
v
in
deps
.
items
()])
and
not
any
(
map
(
is_inplace
,
n
.
users
))
# make sure that item in cnode is valid
if
cnode
:
...
...
colossalai/fx/profiler/__init__.py
View file @
cd5cf2bc
...
...
@@ -7,4 +7,4 @@ else:
from
.experimental
import
meta_profiler_function
,
meta_profiler_module
,
profile_function
,
profile_method
,
profile_module
from
.dataflow
import
GraphInfo
from
.memory
import
parameter_size
,
activation_size
from
.memory
import
parameter_size
,
activation_size
,
is_inplace
colossalai/fx/profiler/dataflow.py
View file @
cd5cf2bc
from
dataclasses
import
dataclass
from
enum
import
Enum
from
functools
import
partial
from
typing
import
Dict
from
torch.fx
import
Graph
,
Node
from
.memory
import
activation_size
from
.memory
import
activation_size
,
is_inplace
from
.
import
META_COMPATIBILITY
if
META_COMPATIBILITY
:
from
.memory
import
NORMALIZATION_ATEN
,
CLONE_ATEN
class
Phase
(
Enum
):
FORWARD
=
0
LOSS
=
1
BACKWARD
=
2
PLACEHOLDER
=
3
BACKWARD
=
1
PLACEHOLDER
=
2
@
dataclass
...
...
@@ -86,8 +87,10 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
def
_peak_memory
(
deps
:
Dict
[
Node
,
int
]):
peak_mem
=
0
for
k
,
v
in
deps
.
items
():
if
v
>
0
:
if
v
>
0
and
is_phase
(
k
,
Phase
.
BACKWARD
)
and
not
any
(
map
(
is_inplace
,
k
.
users
))
:
peak_mem
+=
activation_size
(
k
.
meta
[
'out'
])
if
v
<=
float
(
'-inf'
)
and
is_saved
(
k
)
and
(
k
.
target
not
in
NORMALIZATION_ATEN
):
peak_mem
-=
activation_size
(
k
.
meta
[
'out'
])
return
peak_mem
# deps is used to track all the memory dependencies of the graph.
...
...
@@ -96,7 +99,7 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
for
n
in
graph
.
nodes
:
n
:
Node
if
is_saved
(
n
)
and
not
any
(
map
(
partial
(
is_phase
,
phase
=
Phase
.
LOSS
)
,
n
.
users
)):
if
is_saved
(
n
)
and
(
n
.
target
not
in
NORMALIZATION_ATEN
)
or
any
(
map
(
lambda
x
:
x
.
target
in
CLONE_ATEN
,
n
.
users
)):
# A forward tensor who is marked `save` but is not
# an input to `loss` should be saved during forward.
# If the tensor is a placeholder, then it belongs to `fwd_mem_in`.
...
...
@@ -110,13 +113,14 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
graph_info
.
fwd_mem_tmp
+=
activation_size
(
n
.
meta
[
'out'
])
elif
is_phase
(
n
,
Phase
.
BACKWARD
):
if
len
(
n
.
users
):
# liveness analysis is only used in backward
deps
[
n
]
=
len
(
n
.
users
)
graph_info
.
bwd_mem_tmp
=
max
(
graph_info
.
bwd_mem_tmp
,
_peak_memory
(
deps
))
for
input_n
in
n
.
all_input_nodes
:
if
input_n
in
deps
:
deps
[
input_n
]
-=
1
else
:
# TODO: some of the bwd_mem_out might be model parameters.
# basically a backward node without user is a `grad_out` node
graph_info
.
bwd_mem_out
+=
activation_size
(
n
.
meta
[
'out'
])
for
input_n
in
n
.
all_input_nodes
:
if
input_n
in
deps
:
deps
[
input_n
]
-=
1
if
deps
[
input_n
]
<=
0
:
deps
[
input_n
]
=
float
(
'-inf'
)
return
graph_info
colossalai/fx/profiler/memory.py
View file @
cd5cf2bc
import
torch
from
torch.fx
import
Node
from
typing
import
Union
,
Dict
,
List
,
Tuple
from
operator
import
add
,
floordiv
,
getitem
,
mul
,
neg
,
setitem
,
sub
,
pos
from
.
import
META_COMPATIBILITY
__all__
=
[
'activation_size'
,
'parameter_size'
]
__all__
=
[
'activation_size'
,
'parameter_size'
,
'is_inplace'
]
if
META_COMPATIBILITY
:
aten
=
torch
.
ops
.
aten
...
...
@@ -21,6 +22,7 @@ if META_COMPATIBILITY:
aten
.
bernoulli_
.
float
,
# inplace reshaping
aten
.
copy_
.
default
,
aten
.
detach
.
default
,
aten
.
t
.
default
,
aten
.
transpose
.
int
,
...
...
@@ -28,7 +30,17 @@ if META_COMPATIBILITY:
aten
.
_unsafe_view
.
default
,
]
__all__
+=
[
'INPLACE_ATEN'
,
'WEIRD_OPS'
]
NORMALIZATION_ATEN
=
[
aten
.
native_batch_norm
.
default
,
aten
.
native_layer_norm
.
default
,
# aten.max_pool2d_with_indices.default,
]
CLONE_ATEN
=
[
aten
.
clone
.
default
,
]
__all__
+=
[
'INPLACE_ATEN'
,
'WEIRD_OPS'
,
'NORMALIZATION_ATEN'
,
'CLONE_ATEN'
]
else
:
# TODO fill out the inplace ops
...
...
@@ -106,3 +118,23 @@ def parameter_size(mod: torch.nn.Module) -> int:
for
param
in
mod
.
parameters
():
param_size
+=
param
.
numel
()
*
torch
.
tensor
([],
dtype
=
param
.
dtype
).
element_size
()
return
param_size
def
is_inplace
(
n
:
Node
):
"""Get the inplace argument from torch.fx.Node
Args:
node (Node): torch.fx.Node
Returns:
bool: indicates whether this op is inplace
"""
inplace
=
False
if
n
.
op
==
"call_function"
:
inplace
=
n
.
kwargs
.
get
(
"inplace"
,
False
)
if
META_COMPATIBILITY
and
n
.
target
in
INPLACE_ATEN
:
inplace
=
True
elif
n
.
op
==
"call_module"
:
inplace
=
getattr
(
n
.
graph
.
owning_module
.
get_submodule
(
n
.
target
),
"inplace"
,
False
)
return
inplace
colossalai/fx/profiler/opcount.py
View file @
cd5cf2bc
...
...
@@ -222,6 +222,7 @@ flop_mapping = {
aten
.
_adaptive_avg_pool2d_backward
.
default
:
elementwise_flop_counter
(
0
,
1
),
aten
.
_adaptive_avg_pool3d
.
default
:
elementwise_flop_counter
(
1
,
0
),
aten
.
_adaptive_avg_pool3d_backward
.
default
:
elementwise_flop_counter
(
0
,
1
),
aten
.
embedding_dense_backward
.
default
:
elementwise_flop_counter
(
0
,
1
),
}
elementwise_flop_aten
=
[
...
...
colossalai/fx/profiler/profiler.py
View file @
cd5cf2bc
from
dataclasses
import
dataclass
from
enum
import
auto
from
typing
import
Callable
,
Any
,
Dict
,
Tuple
import
torch
from
torch.fx
import
Graph
,
Node
from
torch.fx.node
import
Argument
,
Target
from
torch.utils._pytree
import
tree_map
from
.dataflow
import
GraphInfo
,
autograd_graph_analysis
,
Phase
from
.memory
import
WEIRD_OPS
,
activation_size
from
.memory
import
WEIRD_OPS
from
.tensor
import
MetaTensor
from
.opcount
import
flop_mapping
...
...
@@ -23,7 +21,7 @@ def is_autogradable(x):
return
isinstance
(
x
,
torch
.
Tensor
)
and
x
.
is_floating_point
()
def
_profile
(
target
:
Callable
,
*
args
,
inplace
=
False
,
**
kwargs
)
->
Tuple
[
Any
,
...]:
def
_profile
(
target
:
Callable
,
*
args
,
**
kwargs
)
->
Tuple
[
Any
,
...]:
"""
Profile a Callable function with args and kwargs.
...
...
@@ -42,7 +40,6 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
# `flop_count`` serves as a global dictionary to store results.
flop_count
=
{
Phase
.
FORWARD
:
0
,
Phase
.
LOSS
:
0
,
Phase
.
BACKWARD
:
0
,
}
...
...
@@ -71,6 +68,10 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
kwargs_node
=
tree_map
(
get_node
,
kwargs
)
node
=
subgraph
.
create_node
(
'call_function'
,
func
,
args_node
,
kwargs_node
)
# do not allocate on `cpu`
if
'device'
in
kwargs
:
kwargs
[
'device'
]
=
'meta'
def
unwrap
(
x
):
# if x is a `nn.Parameter`, we can first wrap it with `FlopTensor`
if
isinstance
(
x
,
torch
.
Tensor
)
and
not
hasattr
(
x
,
'_tensor'
):
...
...
@@ -101,13 +102,13 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
if
target
not
in
WEIRD_OPS
:
def
wrap
(
x
):
return
FlopTensor
(
x
.
detach
().
requires_grad_
(
True
))
if
is_autogradable
(
x
)
and
not
inplace
and
not
hasattr
(
x
,
'_tensor'
)
else
x
return
FlopTensor
(
x
.
detach
().
requires_grad_
(
True
))
if
is_autogradable
(
x
)
and
not
hasattr
(
x
,
'_tensor'
)
else
x
else
:
def
wrap
(
x
):
return
FlopTensor
(
x
.
detach
().
requires_grad_
(
False
))
if
is_autogradable
(
x
)
and
not
inplace
and
not
hasattr
(
x
,
'_tensor'
)
else
x
return
FlopTensor
(
x
.
detach
().
requires_grad_
(
False
))
if
is_autogradable
(
x
)
and
not
hasattr
(
x
,
'_tensor'
)
else
x
# Basically, we need to detach the args and kwargs from the outer graph.
args
=
tree_map
(
wrap
,
args
)
...
...
@@ -125,7 +126,7 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
tree_map
(
set_placeholder
,
kwargs
)
def
pack
(
x
):
if
isinstance
(
x
,
FlopTensor
):
if
isinstance
(
x
,
FlopTensor
)
and
not
isinstance
(
x
,
torch
.
nn
.
Parameter
)
:
x
.
_node
.
meta
[
'saved'
]
=
True
return
x
...
...
@@ -143,13 +144,15 @@ def _profile(target: Callable, *args, inplace=False, **kwargs) -> Tuple[Any, ...
else
:
out
=
target
(
*
args
,
**
kwargs
)
# If the output is not a floating point `torch.Tensor` or it does not
# requires grad, then we should not run backward for this node.
if
is_autogradable
(
out
)
and
out
.
requires_grad
:
phase
=
Phase
.
LOSS
loss
=
out
.
sum
()
phase
=
Phase
.
BACKWARD
loss
.
backward
()
# If the output is not a floating point `torch.Tensor` or it does not
# requires grad, then we should not run backward for this node.
if
is_autogradable
(
out
)
and
out
.
requires_grad
:
phase
=
Phase
.
BACKWARD
if
isinstance
(
out
,
FlopTensor
):
out
.
_node
.
meta
[
'save'
]
=
False
grad
=
torch
.
empty_like
(
out
.
_tensor
,
device
=
'meta'
)
if
isinstance
(
out
,
FlopTensor
)
else
torch
.
empty_like
(
out
,
device
=
'meta'
)
torch
.
autograd
.
backward
(
out
,
FlopTensor
(
grad
))
graph_info
=
autograd_graph_analysis
(
subgraph
)
graph_info
.
fwd_flop
,
graph_info
.
bwd_flop
=
flop_count
[
Phase
.
FORWARD
],
flop_count
[
Phase
.
BACKWARD
]
...
...
@@ -172,7 +175,7 @@ def profile_function(target: 'Target') -> Callable:
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
>>> output, meta_info = profile_function(func)(input
, inplace=False
)
>>> output, meta_info = profile_function(func)(input)
"""
def
f
(
*
args
:
Tuple
[
Argument
,
...],
**
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
...
...
@@ -183,7 +186,7 @@ def profile_function(target: 'Target') -> Callable:
args
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
)
if
isinstance
(
x
,
torch
.
Tensor
)
else
x
,
args
)
kwargs
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
)
if
isinstance
(
x
,
torch
.
Tensor
)
else
x
,
kwargs
)
out
=
func
(
*
args
,
**
kwargs
)
return
out
,
GraphInfo
(
out
.
numel
(),
out
.
numel
(),
activation_size
((
args
,
kwargs
)),
0
,
activation_size
(
out
)
,
0
)
return
out
,
GraphInfo
(
out
.
numel
(),
out
.
numel
(),
0
,
0
,
0
,
0
)
out
,
meta
=
_profile
(
func
,
*
args
,
**
kwargs
)
return
out
,
meta
...
...
@@ -201,7 +204,7 @@ def profile_method(target: 'Target') -> Callable:
def
f
(
*
args
:
Tuple
[
Argument
,
...],
**
kwargs
:
Dict
[
str
,
Any
])
->
Any
:
# execute the method and return the result
assert
isinstance
(
target
,
str
),
f
'
{
target
}
instance is not str.'
out
,
meta
=
_profile
(
target
,
*
args
,
inplace
=
False
,
**
kwargs
)
out
,
meta
=
_profile
(
target
,
*
args
,
**
kwargs
)
return
out
,
meta
return
f
...
...
@@ -230,8 +233,8 @@ def profile_module(module: torch.nn.Module) -> Callable:
args
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
),
args
)
kwargs
=
tree_map
(
lambda
x
:
x
.
to
(
'meta'
),
kwargs
)
out
=
func
(
*
args
,
**
kwargs
)
return
out
,
GraphInfo
(
out
.
numel
(),
out
.
numel
(),
activation_size
((
args
,
kwargs
)),
0
,
activation_size
(
out
)
,
0
)
out
,
meta
=
_profile
(
func
,
*
args
,
inplace
=
getattr
(
module
,
'inplace'
,
False
),
**
kwargs
)
return
out
,
GraphInfo
(
out
.
numel
(),
out
.
numel
(),
0
,
0
,
0
,
0
)
out
,
meta
=
_profile
(
func
,
*
args
,
**
kwargs
)
return
out
,
meta
f
.
__name__
=
module
.
__class__
.
__name__
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment