Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2df2c85b
Unverified
Commit
2df2c85b
authored
Apr 06, 2026
by
Andreas Karatzas
Committed by
GitHub
Apr 07, 2026
Browse files
[Kernels][MoE] Fix legacy_routing to use bitmatrix-based routing path (#38504)
Signed-off-by:
Andreas Karatzas
<
akaratza@amd.com
>
parent
62095e82
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
84 additions
and
216 deletions
+84
-216
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-baseline.yaml
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-baseline.yaml
+1
-1
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml
..._oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml
+2
-2
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml
...oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml
+1
-1
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml
..._oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml
+1
-1
tests/kernels/moe/test_gpt_oss_triton_kernels.py
tests/kernels/moe/test_gpt_oss_triton_kernels.py
+12
-57
tests/kernels/quantization/test_mxfp4_triton_ep.py
tests/kernels/quantization/test_mxfp4_triton_ep.py
+18
-41
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+49
-113
No files found.
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-baseline.yaml
View file @
2df2c85b
...
...
@@ -3,4 +3,4 @@
model_name
:
openai/gpt-oss-20b
metric_threshold
:
0.568
reasoning_effort
:
low
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN"
\ No newline at end of file
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--tensor-parallel-size
2"
\ No newline at end of file
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-aiter.yaml
View file @
2df2c85b
...
...
@@ -3,6 +3,6 @@
model_name
:
amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold
:
0.568
reasoning_effort
:
low
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--moe-backend
aiter"
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--moe-backend
aiter
--tokenizer
openai/gpt-oss-20b
--tensor-parallel-size
2
"
env
:
VLLM_ROCM_USE_AITER
:
"
1"
VLLM_ROCM_USE_AITER
:
"
1"
\ No newline at end of file
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-bf16-triton.yaml
View file @
2df2c85b
...
...
@@ -3,4 +3,4 @@
model_name
:
amd/gpt-oss-20b-w-mxfp4-a-bf16
metric_threshold
:
0.568
reasoning_effort
:
low
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--moe-backend
triton"
\ No newline at end of file
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--moe-backend
triton
--tokenizer
openai/gpt-oss-20b
--tensor-parallel-size
2"
\ No newline at end of file
tests/evals/gpt_oss/configs/gpt-oss-20b-rocm-quark-mxfp4-fp8-triton.yaml
View file @
2df2c85b
...
...
@@ -3,6 +3,6 @@
model_name
:
amd/gpt-oss-20b-MoE-Quant-W-MXFP4-A-FP8-KV-FP8
metric_threshold
:
0.568
reasoning_effort
:
low
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN"
server_args
:
"
--attention-backend
ROCM_AITER_UNIFIED_ATTN
--tensor-parallel-size
2
"
env
:
VLLM_ROCM_USE_AITER
:
"
1"
\ No newline at end of file
tests/kernels/moe/test_gpt_oss_triton_kernels.py
View file @
2df2c85b
...
...
@@ -23,16 +23,12 @@ from triton_kernels.numerics_details.mxfp import downcast_to_mxfp, upcast_from_m
from
triton_kernels.tensor
import
FP4
,
convert_layout
,
wrap_torch_tensor
from
triton_kernels.tensor_details
import
layout
from
triton_kernels.testing
import
assert_close
from
triton_kernels.topk
import
topk
as
topk_fn
from
vllm.model_executor.layers.fused_moe.config
import
mxfp4_w4a16_moe_quant_config
from
vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe
import
(
legacy_routing
,
make_routing_data
,
triton_kernel_moe_forward
,
)
from
vllm.utils.math_utils
import
round_up
from
vllm.utils.torch_utils
import
set_random_seed
from
.utils
import
shuffle_weight
...
...
@@ -97,10 +93,18 @@ def init_compute_data(M, K, N, E, a_dtype: str, w_dtype: str, num_warps: int):
if
w_dtype
!=
"mx4"
:
pytest
.
skip
(
"NYI"
)
else
:
# quantize to mx4
# careful on the padding here, the activation padding need to be
# multiple of 64, the actual engine is not implemented
w1_bottom_pad
=
round_up
(
w1_tri
.
shape
[
1
],
64
)
-
w1_tri
.
shape
[
1
]
w1_right_pad
=
round_up
(
w1_tri
.
shape
[
2
],
128
)
-
w1_tri
.
shape
[
2
]
# Padding alignment depends on the platform. On CDNA4 the scale
# swizzle requires SCALE_K % 8 == 0 (K % 256) and
# SCALE_N % 32 == 0 (2*N % 512), matching the production
# alignment in mxfp4_round_up_hidden_size_and_intermediate_size.
# On CUDA (Hopper) the scale layout pads internally, so the
# original 64/128 alignment is sufficient.
if
current_platform
.
is_rocm
():
k_align
,
n2_align
=
256
,
512
else
:
k_align
,
n2_align
=
64
,
128
w1_bottom_pad
=
round_up
(
w1_tri
.
shape
[
1
],
k_align
)
-
w1_tri
.
shape
[
1
]
w1_right_pad
=
round_up
(
w1_tri
.
shape
[
2
],
n2_align
)
-
w1_tri
.
shape
[
2
]
w2_bottom_pad
=
w1_right_pad
//
2
w2_right_pad
=
w1_bottom_pad
...
...
@@ -367,52 +371,3 @@ def test_unit_shuffle():
)
assert_close
(
ref
=
out_ref
,
tri
=
out
)
@
pytest
.
mark
.
parametrize
(
"num_tokens"
,
[
2
,
8
,
64
])
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
[
32
,
128
])
@
pytest
.
mark
.
parametrize
(
"topk"
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
"renormalize"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
def
test_legacy_routing
(
num_tokens
:
int
,
num_experts
:
int
,
topk
:
int
,
renormalize
:
bool
,
dtype
:
torch
.
dtype
):
set_random_seed
(
0
)
gating_output
=
torch
.
randn
(
num_tokens
,
num_experts
,
device
=
"cuda"
,
dtype
=
dtype
)
sm_first
=
not
renormalize
logits
=
gating_output
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
topk_result
=
topk_fn
(
logits
,
topk
,
apply_softmax
=
not
sm_first
)
# topk_fn returns SparseMatrix on NVIDIA, plain tuple on ROCm.
if
isinstance
(
topk_result
,
tuple
):
topk_weights
,
topk_ids_raw
,
bitmatrix
=
topk_result
from
triton_kernels.routing
import
routing_from_bitmatrix
routing_data_ref
,
gather_indx_ref
,
scatter_indx_ref
=
routing_from_bitmatrix
(
bitmatrix
,
topk_weights
,
topk_ids_raw
,
num_experts
,
topk
)
else
:
topk_ids
=
topk_result
.
indx
.
to
(
torch
.
long
)
topk_weights
=
topk_result
.
vals
routing_data_ref
,
gather_indx_ref
,
scatter_indx_ref
=
make_routing_data
(
topk_ids
,
topk_weights
,
num_experts
)
routing_data
,
gather_indx
,
scatter_indx
=
legacy_routing
(
gating_output
,
topk
,
sm_first
=
sm_first
)
assert_close
(
ref
=
gather_indx_ref
.
src_indx
,
tri
=
gather_indx
.
src_indx
,
maxtol
=
0
,
rmstol
=
0
)
assert_close
(
ref
=
gather_indx_ref
.
dst_indx
,
tri
=
gather_indx
.
dst_indx
,
maxtol
=
0
,
rmstol
=
0
)
assert_close
(
ref
=
scatter_indx_ref
.
src_indx
,
tri
=
scatter_indx
.
src_indx
,
maxtol
=
0
,
rmstol
=
0
)
assert_close
(
ref
=
scatter_indx_ref
.
dst_indx
,
tri
=
scatter_indx
.
dst_indx
,
maxtol
=
0
,
rmstol
=
0
)
tests/kernels/quantization/test_mxfp4_triton_ep.py
View file @
2df2c85b
...
...
@@ -4,12 +4,9 @@
Tests that triton_kernel_moe_forward correctly applies expert_map
remapping when expert parallelism (EP) is enabled.
Previously, legacy_routing was always used and it produced routing data
with global expert IDs that didn't correspond to local weight indices,
causing illegal memory access with EP. The fix splits routing: when
expert_map is provided, topk selection is performed first, expert_map is
applied to remap global→local IDs, and make_routing_data builds routing
structures from the local IDs.
Both EP and non-EP paths use topk + make_routing_data. When expert_map
is provided, global expert IDs are remapped to local IDs before building
routing structures.
"""
from
unittest.mock
import
MagicMock
,
patch
...
...
@@ -24,21 +21,15 @@ class TestTritonMoeForwardExpertMap:
@
pytest
.
mark
.
parametrize
(
"expert_map_present"
,
[
False
,
True
])
def
test_routing_path_selection
(
self
,
expert_map_present
):
"""Verify that th
e
EP
-aware routing path is taken when expert_map
is present, and the legacy_routing path is taken otherwise
."""
"""Verify that
bo
th EP
and non-EP paths use topk + make_routing_data,
and that expert_map remapping is applied when present
."""
device
=
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
# This is a structural test: we mock the routing functions to
# verify the correct path is exercised.
mock_expert_map
=
(
torch
.
tensor
([
0
,
-
1
,
1
,
-
1
],
device
=
device
)
if
expert_map_present
else
None
)
with
(
patch
(
"vllm.model_executor.layers.fused_moe."
"gpt_oss_triton_kernels_moe.legacy_routing"
)
as
mock_legacy
,
patch
(
"triton_kernels.topk.topk"
)
as
mock_topk
,
patch
(
"vllm.model_executor.layers.fused_moe."
...
...
@@ -53,27 +44,19 @@ class TestTritonMoeForwardExpertMap:
triton_kernel_moe_forward
,
)
# Set up return values
mock_routing_data
=
MagicMock
()
mock_gather
=
MagicMock
()
mock_scatter
=
MagicMock
()
if
expert_map_present
:
sparse_result
=
MagicMock
()
sparse_result
.
indx
=
torch
.
tensor
([[
0
,
2
]],
dtype
=
torch
.
int32
)
sparse_result
.
vals
=
torch
.
tensor
([[
0.6
,
0.4
]])
mock_topk
.
return_value
=
sparse_result
mock_make_routing
.
return_value
=
(
mock_routing_data
,
mock_gather
,
mock_scatter
,
)
else
:
mock_legacy
.
return_value
=
(
mock_routing_data
,
mock_gather
,
mock_scatter
,
)
sparse_result
=
MagicMock
()
sparse_result
.
indx
=
torch
.
tensor
([[
0
,
2
]],
dtype
=
torch
.
int32
)
sparse_result
.
vals
=
torch
.
tensor
([[
0.6
,
0.4
]])
mock_topk
.
return_value
=
sparse_result
mock_make_routing
.
return_value
=
(
mock_routing_data
,
mock_gather
,
mock_scatter
,
)
mock_fused_experts
.
return_value
=
torch
.
zeros
((
1
,
8
),
device
=
device
)
...
...
@@ -92,20 +75,14 @@ class TestTritonMoeForwardExpertMap:
expert_map
=
mock_expert_map
,
)
# Both paths use topk + make_routing_data
mock_topk
.
assert_called_once
()
mock_make_routing
.
assert_called_once
()
if
expert_map_present
:
# EP path: should use topk + make_routing_data, NOT
# legacy_routing
mock_topk
.
assert_called_once
()
mock_make_routing
.
assert_called_once
()
mock_legacy
.
assert_not_called
()
# expert_map should be None in the fused_experts call
# (already applied)
call_kwargs
=
mock_fused_experts
.
call_args
assert
call_kwargs
[
1
].
get
(
"expert_map"
)
is
None
or
(
len
(
call_kwargs
[
0
])
>
0
)
else
:
# Non-EP path: should use legacy_routing
mock_legacy
.
assert_called_once
()
mock_topk
.
assert_not_called
()
mock_make_routing
.
assert_not_called
()
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
2df2c85b
...
...
@@ -47,7 +47,6 @@ if has_triton_kernels():
BIT
,
Bitmatrix
,
)
from
triton_kernels.topk
import
topk
try
:
from
triton_kernels.tensor
import
(
...
...
@@ -89,6 +88,7 @@ def pack_bitmatrix(
offsets
=
offsets_m
[:,
None
]
*
n_expts_act
+
offsets_k
[
None
,
:]
mask
=
(
offsets_m
<
n_rows
)[:,
None
]
&
(
offsets_k
<
n_expts_act
)[
None
,
:]
indices
=
tl
.
load
(
topk_ids
+
offsets
,
mask
=
mask
,
other
=-
1
)
valid
=
indices
>=
0
div
=
indices
//
32
rem
=
indices
%
32
one
=
tl
.
cast
(
1
,
tl
.
uint32
)
...
...
@@ -99,8 +99,13 @@ def pack_bitmatrix(
offs
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
//
32
)
+
i
*
(
BLOCK_SIZE_K
//
32
)
# All topks that need to go into this column has the correct bit set.
# Other bits are 0. x is a 2D tensor.
# Guard with `valid` to prevent negative indices from producing
# spurious bits (on HIP, -1 // 32 == 0 and 1 << (-1 % 32) sets
# bit 31).
x
=
tl
.
where
(
div
[:,
:,
None
]
==
offs
[
None
,
None
,
:],
(
one
<<
rem
)[:,
:,
None
],
0
valid
[:,
:,
None
]
&
(
div
[:,
:,
None
]
==
offs
[
None
,
None
,
:]),
(
one
<<
rem
)[:,
:,
None
],
0
,
)
# Reduce x to get a single int32_t bitpack.
y
=
tl
.
reduce_or
(
x
,
axis
=
1
)
...
...
@@ -108,93 +113,6 @@ def pack_bitmatrix(
tl
.
store
(
bitmatrix_ptrs
,
y
,
mask
=
offsets_m
[:,
None
]
<
n_rows
)
def
legacy_routing_from_bitmatrix
(
bitmatrix
:
"Bitmatrix"
,
expt_scal
:
torch
.
Tensor
,
expt_indx
:
torch
.
Tensor
,
n_expts_tot
:
int
,
n_expts_act
:
int
,
)
->
tuple
[
"RoutingData"
,
"GatherIndx"
,
"ScatterIndx"
]:
"""
Replacement for the removed triton_kernels.routing.routing_from_bitmatrix.
Creates routing data from a bitmatrix representation.
"""
if
use_legacy_triton_kernels
:
from
triton_kernels.routing
import
routing_from_bitmatrix
return
routing_from_bitmatrix
(
bitmatrix
,
expt_scal
,
expt_indx
,
n_expts_tot
,
n_expts_act
)
sparse_logits
=
SparseMatrix
(
indx
=
expt_indx
,
vals
=
expt_scal
,
mask
=
bitmatrix
)
dispatch_indx
=
sparse_logits
.
mask_metadata
.
row_sorted_indx
combine_indx
=
sparse_logits
.
mask_metadata
.
col_sorted_indx
ragged_batch_metadata
=
make_ragged_tensor_metadata
(
sparse_logits
.
mask_metadata
.
col_sum
,
dispatch_indx
.
shape
[
0
],
)
gate_scal
=
sparse_logits
.
vals
.
flatten
()[
combine_indx
]
routing_data
=
RoutingData
(
gate_scal
,
ragged_batch_metadata
.
block_sizes
,
n_expts_tot
,
n_expts_act
,
ragged_batch_metadata
,
)
gather_idx
=
GatherIndx
(
combine_indx
,
dispatch_indx
)
scatter_idx
=
ScatterIndx
(
dispatch_indx
,
combine_indx
)
return
routing_data
,
gather_idx
,
scatter_idx
def
legacy_routing_from_sparsematrix
(
sparse_logits
:
"SparseMatrix"
,
n_expts_tot
:
int
,
n_expts_act
:
int
,
)
->
tuple
[
"RoutingData"
,
"GatherIndx"
,
"ScatterIndx"
]:
"""
Creates routing data from a SparseMatrix representation.
"""
dispatch_indx
=
sparse_logits
.
mask_metadata
.
row_sorted_indx
combine_indx
=
sparse_logits
.
mask_metadata
.
col_sorted_indx
ragged_batch_metadata
=
make_ragged_tensor_metadata
(
sparse_logits
.
mask_metadata
.
col_sum
,
dispatch_indx
.
shape
[
0
],
)
gate_scal
=
sparse_logits
.
vals
.
flatten
()[
combine_indx
]
routing_data
=
RoutingData
(
gate_scal
,
ragged_batch_metadata
.
block_sizes
,
n_expts_tot
,
n_expts_act
,
ragged_batch_metadata
,
)
gather_idx
=
GatherIndx
(
combine_indx
,
dispatch_indx
)
scatter_idx
=
ScatterIndx
(
dispatch_indx
,
combine_indx
)
return
routing_data
,
gather_idx
,
scatter_idx
def
legacy_routing
(
logits
:
torch
.
Tensor
,
n_expts_act
:
int
,
sm_first
:
bool
=
False
,
)
->
tuple
[
"RoutingData"
,
"GatherIndx"
,
"ScatterIndx"
]:
"""
Replacement for the removed triton_kernels.routing.routing function.
Computes routing data from gating logits.
"""
if
use_legacy_triton_kernels
:
from
triton_kernels.routing
import
routing
return
routing
(
logits
,
n_expts_act
,
sm_first
=
sm_first
)
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
sparse_logits
=
topk
(
logits
,
n_expts_act
,
apply_softmax
=
not
sm_first
)
return
legacy_routing_from_sparsematrix
(
sparse_logits
,
logits
.
shape
[
-
1
],
n_expts_act
,
)
def
triton_kernel_moe_forward
(
hidden_states
:
torch
.
Tensor
,
w1
,
# Tensor or triton_kernels.Tensor
...
...
@@ -241,26 +159,22 @@ def triton_kernel_moe_forward(
unpadded_K_w2
=
unpadded_K_w2
,
)
from
triton_kernels.topk
import
topk
as
topk_fn
sm_first
=
not
renormalize
logits
=
gating_output
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
topk_result
=
topk_fn
(
logits
,
topk
,
apply_softmax
=
not
sm_first
)
# topk may return a tuple (vals, indx, bitmatrix) or a
# SparseMatrix depending on the triton_kernels version.
if
isinstance
(
topk_result
,
tuple
):
topk_weights
,
topk_ids_raw
,
_
=
topk_result
else
:
topk_weights
=
topk_result
.
vals
topk_ids_raw
=
topk_result
.
indx
if
expert_map
is
not
None
:
# With expert parallelism, legacy_routing produces routing data
# using global expert IDs which don't correspond to local weight
# indices. Split the routing into topk selection + expert_map
# remapping + local routing data construction (matching the
# approach used by OAITritonExperts.apply).
from
triton_kernels.topk
import
topk
as
topk_fn
sm_first
=
not
renormalize
logits
=
gating_output
if
sm_first
:
logits
=
torch
.
softmax
(
logits
,
dim
=-
1
)
topk_result
=
topk_fn
(
logits
,
topk
,
apply_softmax
=
not
sm_first
)
# topk may return a tuple (vals, indx, bitmatrix) or a
# SparseMatrix depending on the triton_kernels version.
if
isinstance
(
topk_result
,
tuple
):
topk_weights
,
topk_ids_raw
,
_
=
topk_result
else
:
topk_weights
=
topk_result
.
vals
topk_ids_raw
=
topk_result
.
indx
# topk_ids_raw contains global expert IDs - remap to local.
topk_ids
=
expert_map
[
topk_ids_raw
.
to
(
torch
.
long
)]
local_num_experts
=
w1
.
shape
[
0
]
...
...
@@ -271,8 +185,9 @@ def triton_kernel_moe_forward(
effective_expert_map
=
None
effective_global_num_experts
=
local_num_experts
else
:
routing_data
,
gather_idx
,
scatter_idx
=
legacy_routing
(
gating_output
,
topk
,
sm_first
=
not
renormalize
topk_ids
=
topk_ids_raw
.
to
(
torch
.
long
)
routing_data
,
gather_idx
,
scatter_idx
=
make_routing_data
(
topk_ids
,
topk_weights
,
gating_output
.
shape
[
-
1
]
)
effective_expert_map
=
expert_map
effective_global_num_experts
=
global_num_experts
...
...
@@ -539,10 +454,31 @@ def make_routing_data(
# matmul_ogs expects invalid topk_weights to be -1s
topk_weights
=
torch
.
where
(
topk_ids
==
-
1
,
-
1.0
,
topk_weights
)
routing_data
,
gather_indx
,
scatter_indx
=
legacy_routing_from_bitmatrix
(
bitmatrix
,
topk_weights
,
topk_ids
,
num_local_experts
,
num_topk
)
if
use_legacy_triton_kernels
:
from
triton_kernels.routing
import
routing_from_bitmatrix
return
routing_from_bitmatrix
(
bitmatrix
,
topk_weights
,
topk_ids
,
num_local_experts
,
num_topk
)
sparse_logits
=
SparseMatrix
(
indx
=
topk_ids
,
vals
=
topk_weights
,
mask
=
bitmatrix
)
dispatch_indx
=
sparse_logits
.
mask_metadata
.
row_sorted_indx
combine_indx
=
sparse_logits
.
mask_metadata
.
col_sorted_indx
ragged_batch_metadata
=
make_ragged_tensor_metadata
(
sparse_logits
.
mask_metadata
.
col_sum
,
dispatch_indx
.
shape
[
0
],
)
gate_scal
=
sparse_logits
.
vals
.
flatten
()[
combine_indx
]
routing_data
=
RoutingData
(
gate_scal
,
ragged_batch_metadata
.
block_sizes
,
num_local_experts
,
num_topk
,
ragged_batch_metadata
,
)
gather_indx
=
GatherIndx
(
combine_indx
,
dispatch_indx
)
scatter_indx
=
ScatterIndx
(
dispatch_indx
,
combine_indx
)
return
routing_data
,
gather_indx
,
scatter_indx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment