Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
935422a4
Commit
935422a4
authored
Dec 20, 2022
by
fsx950223
Browse files
add reduce operation
parent
3679054a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
42 additions
and
26 deletions
+42
-26
example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp
..._sparse_embedding/sparse_embedding3_forward_layernorm.cpp
+12
-9
include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp
...evice/impl/device_sparse_embeddings_forward_layernorm.hpp
+14
-5
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
...gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
+16
-12
No files found.
example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp
View file @
935422a4
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/device_memory.hpp"
...
@@ -24,15 +25,16 @@ using GammaDataType = ck::half_t;
...
@@ -24,15 +25,16 @@ using GammaDataType = ck::half_t;
using
BetaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
using
OutType
=
ck
::
half_t
;
using
OutType
=
ck
::
half_t
;
using
ReduceOperation
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
DeviceInstance_fp16_e256
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
256
,
1
,
1
,
3
>
;
using
DeviceInstance_fp16_e256
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
ReduceOperation
,
256
,
1
,
256
,
1
,
256
,
1
,
1
,
3
>
;
using
DeviceInstance_fp16_e512
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
512
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e512
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
ReduceOperation
,
256
,
1
,
256
,
1
,
512
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e768
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
768
,
1
,
1
,
3
>
;
using
DeviceInstance_fp16_e768
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
ReduceOperation
,
256
,
1
,
256
,
1
,
768
,
1
,
1
,
3
>
;
using
DeviceInstance_fp16_e1024
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
1024
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e1024
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
ReduceOperation
,
256
,
1
,
256
,
1
,
1024
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e1536
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
1536
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e1536
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
ReduceOperation
,
256
,
1
,
256
,
1
,
1536
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e2048
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
2048
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e2048
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
ReduceOperation
,
256
,
1
,
256
,
1
,
2048
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e4096
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
4096
,
1
,
8
,
3
>
;
using
DeviceInstance_fp16_e4096
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
ReduceOperation
,
256
,
1
,
256
,
1
,
4096
,
1
,
8
,
3
>
;
using
DeviceInstance_fp16_e8192
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
8192
,
1
,
8
,
3
>
;
using
DeviceInstance_fp16_e8192
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
ReduceOperation
,
256
,
1
,
256
,
1
,
8192
,
1
,
8
,
3
>
;
template
<
typename
emb_type
,
ck
::
index_t
dim
>
struct
emb_kernel
{};
template
<
typename
emb_type
,
ck
::
index_t
dim
>
struct
emb_kernel
{};
...
@@ -134,7 +136,8 @@ int main()
...
@@ -134,7 +136,8 @@ int main()
beta_dev
.
GetDeviceBuffer
(),
beta_dev
.
GetDeviceBuffer
(),
current_dim
,
current_dim
,
index_length
,
index_length
,
epsilon
);
epsilon
,
ReduceOperation
{});
std
::
cout
<<
"Dim:"
<<
current_dim
<<
", kernel:"
<<
device_instance
.
GetTypeString
()
std
::
cout
<<
"Dim:"
<<
current_dim
<<
", kernel:"
<<
device_instance
.
GetTypeString
()
<<
std
::
endl
<<
std
::
endl
<<
std
::
flush
;
<<
std
::
flush
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp
View file @
935422a4
...
@@ -24,6 +24,7 @@ template <typename EmbType,
...
@@ -24,6 +24,7 @@ template <typename EmbType,
typename
BetaDataType
,
typename
BetaDataType
,
typename
AccDataType
,
typename
AccDataType
,
typename
OutType
,
typename
OutType
,
typename
ReduceOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
DimClusterSize
,
ck
::
index_t
DimClusterSize
,
ck
::
index_t
RowClusterSize
,
ck
::
index_t
RowClusterSize
,
...
@@ -48,7 +49,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
...
@@ -48,7 +49,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
const
BetaDataType
*
p_beta
,
const
BetaDataType
*
p_beta
,
const
ck
::
index_t
EmbeddingDim
,
const
ck
::
index_t
EmbeddingDim
,
const
ck
::
index_t
IndexLength
,
const
ck
::
index_t
IndexLength
,
const
AccDataType
epsilon
)
const
AccDataType
epsilon
,
const
ReduceOperation
reduce_op
)
:
p_out_
(
p_out
),
:
p_out_
(
p_out
),
p_embs_
(
p_embs
),
p_embs_
(
p_embs
),
p_indexs_
(
p_indexs
),
p_indexs_
(
p_indexs
),
...
@@ -56,7 +58,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
...
@@ -56,7 +58,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
p_beta_
(
p_beta
),
p_beta_
(
p_beta
),
EmbeddingDim_
(
EmbeddingDim
),
EmbeddingDim_
(
EmbeddingDim
),
IndexLength_
(
IndexLength
),
IndexLength_
(
IndexLength
),
epsilon_
(
epsilon
)
epsilon_
(
epsilon
),
reduce_op_
(
reduce_op
)
{
{
grid_size_
=
(
IndexLength
+
DimClusterSize
-
1
)
/
DimClusterSize
;
grid_size_
=
(
IndexLength
+
DimClusterSize
-
1
)
/
DimClusterSize
;
}
}
...
@@ -69,6 +72,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
...
@@ -69,6 +72,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
ck
::
index_t
EmbeddingDim_
;
ck
::
index_t
EmbeddingDim_
;
ck
::
index_t
IndexLength_
;
ck
::
index_t
IndexLength_
;
AccDataType
epsilon_
;
AccDataType
epsilon_
;
ReduceOperation
reduce_op_
;
size_t
grid_size_
;
size_t
grid_size_
;
};
};
...
@@ -81,7 +85,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
...
@@ -81,7 +85,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
const
void
*
p_beta
,
const
void
*
p_beta
,
ck
::
index_t
EmbeddingDim
,
ck
::
index_t
EmbeddingDim
,
ck
::
index_t
IndexLength
,
ck
::
index_t
IndexLength
,
const
AccDataType
epsilon
)
const
AccDataType
epsilon
,
const
ReduceOperation
reduce_op
)
{
{
return
std
::
make_unique
<
Argument
>
(
reinterpret_cast
<
OutType
*>
(
p_out
),
return
std
::
make_unique
<
Argument
>
(
reinterpret_cast
<
OutType
*>
(
p_out
),
p_embs
,
p_embs
,
...
@@ -90,7 +95,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
...
@@ -90,7 +95,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
reinterpret_cast
<
const
BetaDataType
*>
(
p_beta
),
reinterpret_cast
<
const
BetaDataType
*>
(
p_beta
),
EmbeddingDim
,
EmbeddingDim
,
IndexLength
,
IndexLength
,
epsilon
);
epsilon
,
reduce_op
);
}
}
using
GridwiseSparseEmbedding
=
using
GridwiseSparseEmbedding
=
...
@@ -101,6 +107,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
...
@@ -101,6 +107,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
AccDataType
,
AccDataType
,
OutType
,
OutType
,
decltype
(
MakeOutputDescriptor
(
1
,
1
)),
decltype
(
MakeOutputDescriptor
(
1
,
1
)),
ReduceOperation
,
BlockSize
,
BlockSize
,
DimClusterSize
,
DimClusterSize
,
RowClusterSize
,
RowClusterSize
,
...
@@ -124,6 +131,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
...
@@ -124,6 +131,7 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
AccDataType
,
AccDataType
,
OutType
,
OutType
,
decltype
(
out_desc
),
decltype
(
out_desc
),
ReduceOperation
,
NumEmbeddings
>
;
NumEmbeddings
>
;
float
avg_time
=
0
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
avg_time
+=
launch_and_time_kernel
(
stream_config
,
...
@@ -137,7 +145,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
...
@@ -137,7 +145,8 @@ struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator
arg
.
p_gamma_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
arg
.
p_beta_
,
out_desc
,
out_desc
,
arg
.
epsilon_
);
arg
.
epsilon_
,
arg
.
reduce_op_
);
return
(
avg_time
);
return
(
avg_time
);
}
}
...
...
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
View file @
935422a4
...
@@ -18,6 +18,7 @@ template <typename GridwiseSparseEmbedding,
...
@@ -18,6 +18,7 @@ template <typename GridwiseSparseEmbedding,
typename
AccDataType
,
typename
AccDataType
,
typename
OutType
,
typename
OutType
,
typename
OutGridDesc
,
typename
OutGridDesc
,
typename
ReduceOperation
,
ck
::
index_t
NumEmbeddings
>
ck
::
index_t
NumEmbeddings
>
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
@@ -29,9 +30,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...
@@ -29,9 +30,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
const
GammaDataType
*
p_gamma
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
const
BetaDataType
*
p_beta
,
const
OutGridDesc
out_grid_desc
,
const
OutGridDesc
out_grid_desc
,
const
AccDataType
epsilon
)
const
AccDataType
epsilon
,
const
ReduceOperation
reduce_op
)
{
{
GridwiseSparseEmbedding
::
Run
(
p_out
,
p_embs
,
p_indexes
,
p_gamma
,
p_beta
,
out_grid_desc
,
epsilon
);
GridwiseSparseEmbedding
::
Run
(
p_out
,
p_embs
,
p_indexes
,
p_gamma
,
p_beta
,
out_grid_desc
,
epsilon
,
reduce_op
);
}
}
template
<
typename
EmbType
,
template
<
typename
EmbType
,
...
@@ -41,6 +43,7 @@ template <typename EmbType,
...
@@ -41,6 +43,7 @@ template <typename EmbType,
typename
AccDataType
,
typename
AccDataType
,
typename
OutType
,
typename
OutType
,
typename
OutGridDesc
,
typename
OutGridDesc
,
typename
ReduceOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
BlockSize
,
ck
::
index_t
DimClusterSize
,
ck
::
index_t
DimClusterSize
,
ck
::
index_t
RowClusterSize
,
ck
::
index_t
RowClusterSize
,
...
@@ -91,7 +94,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
...
@@ -91,7 +94,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
const
GammaDataType
*
p_gamma
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
const
BetaDataType
*
p_beta
,
const
OutGridDesc
,
const
OutGridDesc
,
const
AccDataType
epsilon
)
const
AccDataType
epsilon
,
const
ReduceOperation
reduce_op
)
{
{
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
...
@@ -149,11 +153,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
...
@@ -149,11 +153,11 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
auto
thread_offset
=
(
thread_row_cluster_id
+
i_row_sub_
*
RowClusterSize
)
*
auto
thread_offset
=
(
thread_row_cluster_id
+
i_row_sub_
*
RowClusterSize
)
*
sizeof
(
EmbType
)
*
RowVectorSize
;
sizeof
(
EmbType
)
*
RowVectorSize
;
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
IndexType
index
=
index_bufs
[
i_embedding_
.
value
][
Number
<
current_dim
>
{}];
IndexType
index
=
index_bufs
[
i_embedding_
][
Number
<
current_dim
>
{}];
int32x4_t
emb_res
=
make_wave_buffer_resource_with_default_range
(
int32x4_t
emb_res
=
make_wave_buffer_resource_with_default_range
(
p_embs
[
i_embedding_
.
value
]
+
index
*
RowPerBlock
);
p_embs
[
i_embedding_
]
+
index
*
RowPerBlock
);
emb_vectors
(
i_embedding_
.
value
).
template
AsType
<
src_vector_t
>()(
I0
)
=
emb_vectors
(
i_embedding_
).
template
AsType
<
src_vector_t
>()(
I0
)
=
amd_buffer_load_impl
<
EmbType
,
RowVectorSize
>
(
emb_res
,
thread_offset
,
0
);
amd_buffer_load_impl
<
EmbType
,
RowVectorSize
>
(
emb_res
,
thread_offset
,
0
);
});
});
...
@@ -161,8 +165,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
...
@@ -161,8 +165,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
constexpr
auto
register_offset
=
thread_buf_desc
.
CalculateOffset
(
constexpr
auto
register_offset
=
thread_buf_desc
.
CalculateOffset
(
make_tuple
(
i_dim_sub_
,
i_dim_vec_
,
i_row_sub_
,
i_row_vec_
));
make_tuple
(
i_dim_sub_
,
i_dim_vec_
,
i_row_sub_
,
i_row_vec_
));
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
in_thread_bufs
(
i_embedding_
.
value
)(
Number
<
register_offset
>
{})
=
in_thread_bufs
(
i_embedding_
)(
Number
<
register_offset
>
{})
=
emb_vectors
[
i_embedding_
.
value
].
template
AsType
<
EmbType
>()[
i_row_vec_
];
emb_vectors
[
i_embedding_
].
template
AsType
<
EmbType
>()[
i_row_vec_
];
});
});
});
});
});
});
...
@@ -174,8 +178,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
...
@@ -174,8 +178,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
constexpr
auto
register_offset
=
thread_buf_desc
.
CalculateOffset
(
constexpr
auto
register_offset
=
thread_buf_desc
.
CalculateOffset
(
make_tuple
(
i_dim_sub_
,
i_dim_vec_
,
i_row_sub_
,
i_row_vec_
));
make_tuple
(
i_dim_sub_
,
i_dim_vec_
,
i_row_sub_
,
i_row_vec_
));
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
acc_thread_buf
(
Number
<
register_offset
>
{})
+=
ck
::
type_convert
<
AccDataType
>
(
reduce_op
(
acc_thread_buf
(
Number
<
register_offset
>
{})
,
acc_thread_buf
(
Number
<
register_offset
>
{}),
ck
::
type_convert
<
AccDataType
>
(
in_thread_bufs
(
i_embedding_
.
value
)(
Number
<
register_offset
>
{}));
in_thread_bufs
(
i_embedding_
)(
Number
<
register_offset
>
{}))
)
;
});
});
});
});
});
});
...
@@ -237,8 +241,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
...
@@ -237,8 +241,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
ck
::
static_for
<
0
,
DimPerBlock
,
1
>
{}([
&
](
auto
i_idx_
)
{
ck
::
static_for
<
0
,
DimPerBlock
,
1
>
{}([
&
](
auto
i_idx_
)
{
// prefer use s_load
// prefer use s_load
ck
::
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
ck
::
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
index_bufs
(
i_embedding_
.
value
)(
i_idx_
)
=
index_bufs
(
i_embedding_
)(
i_idx_
)
=
p_indexes
[
i_embedding_
.
value
][
index_start
+
i_idx_
.
value
];
p_indexes
[
i_embedding_
][
index_start
+
i_idx_
.
value
];
});
});
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment