Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b5b4a398
Unverified
Commit
b5b4a398
authored
Apr 26, 2024
by
SangBin Cho
Committed by
GitHub
Apr 25, 2024
Browse files
[Mypy] Typing lora folder (#4337)
parent
f4bc4de1
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
91 additions
and
70 deletions
+91
-70
.github/workflows/mypy.yaml
.github/workflows/mypy.yaml
+3
-4
format.sh
format.sh
+1
-1
vllm/lora/layers.py
vllm/lora/layers.py
+22
-13
vllm/lora/lora.py
vllm/lora/lora.py
+17
-11
vllm/lora/models.py
vllm/lora/models.py
+34
-30
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+12
-9
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+2
-2
No files found.
.github/workflows/mypy.yaml
View file @
b5b4a398
...
...
@@ -33,8 +33,6 @@ jobs:
-
name
:
Mypy
run
:
|
mypy vllm/attention --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
...
...
@@ -44,8 +42,9 @@ jobs:
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/lora --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
# TODO(sang): Fix nested dir
# mypy vllm/lora/*.py --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
format.sh
View file @
b5b4a398
...
...
@@ -106,7 +106,7 @@ mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker
--config-file
pyproject.toml
mypy vllm/spec_decode
--config-file
pyproject.toml
mypy vllm/model_executor/
*
.py
--config-file
pyproject.toml
#
mypy vllm/lora
/*.py
--config-file pyproject.toml
mypy vllm/lora
--config-file
pyproject.toml
CODESPELL_EXCLUDES
=(
...
...
vllm/lora/layers.py
View file @
b5b4a398
...
...
@@ -176,6 +176,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def
__init__
(
self
,
base_layer
:
VocabParallelEmbedding
)
->
None
:
super
().
__init__
()
self
.
base_layer
=
base_layer
self
.
embeddings_slice
:
Optional
[
Tuple
[
int
,
int
]]
self
.
embeddings_weights
:
Optional
[
torch
.
Tensor
]
def
create_lora_weights
(
self
,
...
...
@@ -233,9 +235,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self
.
lora_a_stacked
.
shape
[
0
]
*
self
.
lora_a_stacked
.
shape
[
1
],
self
.
lora_a_stacked
.
shape
[
2
],
)
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
indices_len
:
Optional
[
List
[
int
]]
=
None
self
.
embeddings_indices
=
None
# Lazily initialized.
self
.
indices
:
torch
.
Tensor
self
.
indices_len
:
List
[
int
]
self
.
embeddings_indices
:
torch
.
Tensor
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
...
...
@@ -267,6 +270,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self
.
embeddings_tensors
.
shape
[
1
],
self
.
embeddings_tensors
.
shape
[
2
]
)[
self
.
embeddings_slice
[
0
]:
self
.
embeddings_slice
[
1
]]
assert
self
.
embeddings_weights
is
not
None
self
.
embeddings_weights
[:
embeddings
.
shape
[
0
]].
copy_
(
embeddings
)
def
set_mapping
(
...
...
@@ -343,11 +347,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
indices_len
:
Optional
[
List
[
int
]]
=
None
self
.
output_dim
=
self
.
lora_b_stacked
.
shape
[
2
]
# lazily initialized.
self
.
indices
:
torch
.
Tensor
self
.
indices_len
:
List
[
int
]
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_b_stacked
[
index
]
=
0
...
...
@@ -475,8 +480,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
device
=
self
.
device
,
)
for
_
in
range
(
n_slices
))
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
output_dim
=
self
.
lora_b_stacked
[
0
].
shape
[
2
]
# Lazily initialized.
self
.
indices
:
torch
.
Tensor
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
0
][
index
]
=
0
...
...
@@ -690,7 +696,8 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
kv_proj_shard_size
)
self
.
packed_indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
standard_indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
indices_len
:
Optional
[
List
[
int
]]
=
None
# lazily initialized.
self
.
indices_len
:
List
[
int
]
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
0
][
index
]
=
0
...
...
@@ -814,8 +821,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
)
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
indices_len
:
Optional
[
List
[
int
]]
=
None
# Lazily initialized
self
.
indices
:
torch
.
Tensor
self
.
indices_len
:
List
[
int
]
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
...
...
@@ -991,9 +999,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype
=
self
.
dtype
,
device
=
self
.
device
,
)
self
.
indices
=
None
self
.
indices_padded
=
None
self
.
indices_len
=
None
# Lazily initialized.
self
.
indices
:
torch
.
Tensor
self
.
indices_len
:
List
[
int
]
self
.
indices_padded
:
torch
.
Tensor
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
...
...
vllm/lora/lora.py
View file @
b5b4a398
...
...
@@ -97,9 +97,9 @@ class PackedLoRALayerWeights(LoRALayerWeights):
self
,
module_name
:
str
,
rank
:
int
,
lora_alphas
:
List
[
int
],
lora_a
:
List
[
torch
.
Tensor
],
lora_b
:
List
[
torch
.
Tensor
],
lora_alphas
:
List
[
Optional
[
int
]
]
,
lora_a
:
List
[
Optional
[
torch
.
Tensor
]
]
,
lora_b
:
List
[
Optional
[
torch
.
Tensor
]
]
,
scaling
:
Optional
[
List
[
float
]]
=
None
,
)
->
None
:
super
().
__init__
(
...
...
@@ -108,17 +108,20 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alpha
=
0
,
lora_a
=
lora_a
,
lora_b
=
lora_b
,
scaling
=
scaling
,
scaling
=
scaling
,
# type: ignore
embeddings_tensor
=
None
,
)
self
.
lora_alphas
=
lora_alphas
if
scaling
is
None
:
self
.
scaling
=
[
lora_alpha
/
self
.
rank
for
lora_alpha
in
self
.
lora_alphas
self
.
scaling
=
[
# type: ignore
lora_alpha
/
self
.
rank
# type: ignore # noqa
for
lora_alpha
in
self
.
lora_alphas
]
@
classmethod
def
pack
(
cls
,
loras
:
List
[
"LoRALayerWeights"
])
->
"PackedLoRALayerWeights"
:
def
pack
(
cls
,
loras
:
List
[
Optional
[
"LoRALayerWeights"
]]
)
->
"PackedLoRALayerWeights"
:
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
...
...
@@ -136,16 +139,19 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[
lora
.
lora_alpha
if
lora
is
not
None
else
None
for
lora
in
loras
],
[
lora
.
lora_a
if
lora
is
not
None
else
None
for
lora
in
loras
],
[
lora
.
lora_b
if
lora
is
not
None
else
None
for
lora
in
loras
],
scaling
=
[
1
if
lora
is
not
None
else
None
for
lora
in
loras
])
scaling
=
[
1
if
lora
is
not
None
else
None
# type: ignore
for
lora
in
loras
])
return
obj
def
optimize
(
self
)
->
"PackedLoRALayerWeights"
:
"""Optimize the LoRA by merging the scaling into lora_b."""
for
i
in
range
(
len
(
self
.
lora_b
)):
if
self
.
scaling
[
i
]
==
1
or
self
.
lora_b
[
i
]
is
None
:
if
self
.
scaling
[
i
]
==
1
or
self
.
lora_b
[
i
]
is
None
:
# type: ignore
continue
self
.
lora_b
[
i
]
*=
self
.
scaling
[
i
]
self
.
scaling
[
i
]
=
1
self
.
lora_b
[
i
]
*=
self
.
scaling
[
i
]
# type: ignore
self
.
scaling
[
i
]
=
1
# type: ignore
return
self
@
property
...
...
vllm/lora/models.py
View file @
b5b4a398
...
...
@@ -3,7 +3,7 @@ import json
import
math
import
os
import
re
from
typing
import
Callable
,
Dict
,
Hashable
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
safetensors.torch
import
torch
...
...
@@ -53,44 +53,46 @@ def convert_mapping(
embeddings.
indices_len: List of lengths of the above tensors.
"""
ind
ices
=
list
(
mapping
.
index_mapping
).
copy
()
embedding_indices
=
indices
.
copy
()
lora_indices
=
indices
.
copy
()
prompt_mapping
=
[
ind
ex_mapping_indices
:
List
[
int
]
=
list
(
mapping
.
index_mapping
).
copy
()
embedding_indices
=
index_mapping_
indices
.
copy
()
lora_indices
=
index_mapping_
indices
.
copy
()
prompt_mapping
:
List
[
int
]
=
[
lora_index_to_id
.
index
(
x
)
if
x
>
0
else
-
1
for
x
in
mapping
.
prompt_mapping
]
lora_idx
=
None
for
i
in
range
(
len
(
indices
)):
for
i
in
range
(
len
(
index_mapping_
indices
)):
# TODO index can be slow. optimize
lora_idx
=
(
lora_index_to_id
.
index
(
indices
[
i
])
if
indices
[
i
]
>
0
else
-
1
)
embedding_indices
[
i
]
=
lora_idx
if
indices
[
i
]
>
0
else
0
indices
[
i
]
=
i
lora_idx
=
(
lora_index_to_id
.
index
(
index_mapping_
indices
[
i
])
if
index_mapping_
indices
[
i
]
>
0
else
-
1
)
embedding_indices
[
i
]
=
lora_idx
if
index_mapping_
indices
[
i
]
>
0
else
0
index_mapping_
indices
[
i
]
=
i
lora_indices
[
i
]
=
lora_idx
indices
=
torch
.
tensor
([
indices
,
lora_indices
,
embedding_indices
],
dtype
=
torch
.
long
,
device
=
"cuda"
)
prompt_mapping
=
torch
.
tensor
(
prompt_mapping
,
device
=
"cuda"
,
dtype
=
torch
.
long
)
indices
=
torch
.
tensor
(
[
index_mapping_indices
,
lora_indices
,
embedding_indices
],
dtype
=
torch
.
long
,
device
=
"cuda"
)
prompt_mapping_tensor
=
torch
.
tensor
(
prompt_mapping
,
device
=
"cuda"
,
dtype
=
torch
.
long
)
embeddings_indices
=
torch
.
stack
([
indices
[
2
]
*
extra_vocab_size
,
indices
[
2
]
*
(
vocab_size
+
extra_vocab_size
)
])
embeddings_indices
[
embeddings_indices
==
-
1
]
=
max_loras
-
1
base_indices
=
indices
[
1
]
sampler_indices
=
prompt_mapping
sampler_indices
=
prompt_mapping
_tensor
sampler_indices_padded
=
sampler_indices
.
clone
()
sampler_indices_padded
[
sampler_indices_padded
==
-
1
]
=
max_loras
-
1
sampler_indices_padded
=
(
torch
.
arange
(
0
,
len
(
sampler_indices_padded
),
device
=
"cuda"
,
dtype
=
torch
.
long
)
+
(
sampler_indices_padded
*
len
(
sampler_indices_padded
)))
indices_len
=
(
base_indices
.
shape
[
-
1
],
sampler_indices
.
shape
[
-
1
],
sampler_indices_padded
.
shape
[
-
1
],
embeddings_indices
.
shape
[
-
1
])
indices_len
=
[
base_indices
.
shape
[
-
1
],
sampler_indices
.
shape
[
-
1
],
sampler_indices_padded
.
shape
[
-
1
],
embeddings_indices
.
shape
[
-
1
]
]
return
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
indices_len
)
...
...
@@ -149,6 +151,7 @@ class LoRAModel:
if
module_name
not
in
loras
:
lora_embeddings_tensor
=
None
if
embeddings
:
assert
embedding_modules
is
not
None
embeddings_module
=
next
(
(
k
for
k
in
embedding_modules
if
k
in
module_name
),
None
)
...
...
@@ -171,6 +174,7 @@ class LoRAModel:
else
:
loras
[
module_name
].
lora_b
=
tensor
.
to
(
device
=
device
,
dtype
=
dtype
).
t
()
assert
embedding_padding_modules
is
not
None
if
any
(
name
in
module_name
for
name
in
embedding_padding_modules
)
and
target_embedding_padding
is
not
None
:
...
...
@@ -295,11 +299,10 @@ class LoRAModelManager:
self
.
max_num_batched_tokens
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
self
.
offsets
=
[]
# 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self
.
indices_len
=
[
None
]
*
4
self
.
indices_len
:
List
[
Optional
[
int
]]
=
[
None
]
*
4
self
.
model
:
nn
.
Module
=
model
if
hasattr
(
self
.
model
,
"supported_lora_modules"
):
...
...
@@ -312,7 +315,7 @@ class LoRAModelManager:
self
.
_registered_loras
:
Dict
[
int
,
LoRAModel
]
=
{}
# Dict instead of a Set for compatibility with LRUCache.
self
.
_active_loras
:
Dict
[
int
,
None
]
=
{}
self
.
_last_mapping
=
None
self
.
_last_mapping
:
Optional
[
LoRAMapping
]
=
None
self
.
_create_lora_modules
()
self
.
model
.
lora_manager
=
self
...
...
@@ -370,7 +373,7 @@ class LoRAModelManager:
return
True
return
False
def
_add_lora
(
self
,
lora
:
LoRAModel
)
->
bool
:
def
_add_lora
(
self
,
lora
:
LoRAModel
):
self
.
_create_merged_loras_inplace
(
lora
)
self
.
_registered_loras
[
lora
.
id
]
=
lora
...
...
@@ -418,7 +421,7 @@ class LoRAModelManager:
def
get_lora
(
self
,
lora_id
:
int
)
->
Optional
[
LoRAModel
]:
return
self
.
_registered_loras
.
get
(
lora_id
,
None
)
def
remove_all_loras
(
self
)
->
bool
:
def
remove_all_loras
(
self
):
"""Remove all LoRAModels from the manager."""
self
.
_registered_loras
.
clear
()
self
.
lora_index_to_id
=
[
None
]
*
self
.
lora_slots
...
...
@@ -467,6 +470,7 @@ class LoRAModelManager:
continue
parts
=
module_name
.
split
(
"."
)
if
module_name
not
in
self
.
packed_modules
:
assert
embedding_modules
is
not
None
if
parts
[
-
1
]
in
embedding_modules
:
input_dim
=
(
module
.
base_layer
.
org_vocab_size
+
self
.
lora_config
.
lora_extra_vocab_size
if
...
...
@@ -500,7 +504,7 @@ class LoRAModelManager:
else
:
parts
=
module_name
.
split
(
"."
)
replacements
=
self
.
packed_modules_mapping
[
parts
[
-
1
]]
subloras
=
[]
subloras
:
List
[
Optional
[
"LoRALayerWeights"
]]
=
[]
for
i
,
r
in
enumerate
(
replacements
):
lora
=
LoRALayerWeights
.
create_dummy_lora_weights
(
module_name
+
"."
+
r
,
...
...
@@ -538,7 +542,7 @@ class LoRAModelManager:
def
_create_merged_loras_inplace
(
self
,
lora_model
:
LoRAModel
)
->
None
:
for
module_name
,
new_module_names
in
self
.
packed_modules
.
items
():
replacement_loras
=
[]
replacement_loras
:
List
[
Optional
[
LoRALayerWeights
]]
=
[]
has_replacement
=
False
for
r
in
new_module_names
:
lora
=
lora_model
.
get_lora
(
r
)
...
...
@@ -557,12 +561,12 @@ class LoRAModelManager:
class
LoRALRUCache
(
LRUCache
[
LoRAModel
]):
def
__init__
(
self
,
capacity
:
int
,
deactivate_lora_fn
:
Callable
[[
Hashable
],
None
]):
def
__init__
(
self
,
capacity
:
int
,
deactivate_lora_fn
:
Callable
[[
int
],
bool
]):
super
().
__init__
(
capacity
)
self
.
deactivate_lora_fn
=
deactivate_lora_fn
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
LoRAModel
):
def
_on_remove
(
self
,
key
:
int
,
value
:
LoRAModel
):
logger
.
debug
(
f
"Removing LoRA. int id:
{
key
}
"
)
self
.
deactivate_lora_fn
(
key
)
return
super
().
_on_remove
(
key
,
value
)
...
...
vllm/lora/worker_manager.py
View file @
b5b4a398
from
abc
import
ABC
,
abstractmethod
,
abstractproperty
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Type
from
typing
import
Any
,
Dict
,
List
,
Set
,
Type
import
torch
...
...
@@ -37,7 +37,7 @@ class AbstractWorkerLoRAManager(ABC):
...
@
abstractmethod
def
set_active_loras
(
self
,
lora_requests
:
Lis
t
[
LoRARequest
],
def
set_active_loras
(
self
,
lora_requests
:
Se
t
[
LoRARequest
],
lora_mapping
:
LoRAMapping
)
->
None
:
...
...
...
@@ -54,7 +54,7 @@ class AbstractWorkerLoRAManager(ABC):
...
@
abstractmethod
def
remove_all_loras
(
self
)
->
bool
:
def
remove_all_loras
(
self
):
...
@
abstractmethod
...
...
@@ -81,10 +81,11 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
embedding_padding_modules
:
List
[
str
],
lora_model_cls
:
Type
[
LoRAModel
]
=
LoRAModel
,
):
self
.
_lora_manager
:
Optional
[
LoRAModelManager
]
=
None
self
.
_lora_model_cls
=
lora_model_cls
self
.
embedding_modules
=
embedding_modules
self
.
embedding_padding_modules
=
embedding_padding_modules
# Lazily initialized by create_lora_manager.
self
.
_lora_manager
:
LoRAModelManager
super
().
__init__
(
max_num_seqs
,
max_num_batched_tokens
,
vocab_size
,
lora_config
,
device
)
...
...
@@ -104,7 +105,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
lora_config
=
self
.
lora_config
,
lora_manager_cls
=
self
.
_lora_manager_cls
,
)
self
.
_lora_manager
:
LoRAModelManager
=
lora_manager
self
.
_lora_manager
=
lora_manager
return
lora_manager
.
model
def
set_active_loras
(
self
,
lora_requests
:
Set
[
LoRARequest
],
...
...
@@ -188,7 +189,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
_lora_manager
.
remove_lora
(
lora_id
)
def
remove_all_loras
(
self
)
->
bool
:
def
remove_all_loras
(
self
):
self
.
_lora_manager
.
remove_all_loras
()
def
list_loras
(
self
)
->
Set
[
int
]:
...
...
@@ -217,10 +218,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
lora_config
=
self
.
lora_config
,
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
)
self
.
_lora_manager
:
LRUCacheLoRAModelManager
=
lora_manager
self
.
_lora_manager
=
lora_manager
return
lora_manager
.
model
def
_apply_loras
(
self
,
lora_requests
:
Lis
t
[
LoRARequest
])
->
None
:
def
_apply_loras
(
self
,
lora_requests
:
Se
t
[
LoRARequest
])
->
None
:
loras_map
=
{
lora_request
.
lora_int_id
:
lora_request
for
lora_request
in
lora_requests
if
lora_request
...
...
@@ -237,12 +238,14 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
if
lora_request
.
lora_int_id
not
in
self
.
list_loras
():
# Remove before we load the new lora to save memory
if
len
(
self
.
_lora_manager
)
+
1
>
self
.
_lora_manager
.
capacity
:
assert
isinstance
(
self
.
_lora_manager
,
LRUCacheLoRAModelManager
)
self
.
_lora_manager
.
remove_oldest_lora
()
lora
=
self
.
_load_lora
(
lora_request
)
loaded
=
self
.
_lora_manager
.
add_lora
(
lora
)
else
:
# If the lora is already loaded, just touch it to
# update its position in the caches
loaded
=
self
.
_lora_manager
.
get_lora
(
lora_request
.
lora_int_id
)
loaded
=
self
.
_lora_manager
.
get_lora
(
lora_request
.
lora_int_id
)
is
not
None
self
.
_lora_manager
.
activate_lora
(
lora_request
.
lora_int_id
)
return
loaded
vllm/worker/model_runner.py
View file @
b5b4a398
...
...
@@ -928,10 +928,10 @@ class ModelRunner:
torch
.
cuda
.
synchronize
()
return
def
remove_all_loras
(
self
)
->
bool
:
def
remove_all_loras
(
self
):
if
not
self
.
lora_manager
:
raise
RuntimeError
(
"LoRA is not enabled."
)
return
self
.
lora_manager
.
remove_all_loras
()
self
.
lora_manager
.
remove_all_loras
()
def
set_active_loras
(
self
,
lora_requests
:
Set
[
LoRARequest
],
lora_mapping
:
LoRAMapping
)
->
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment