Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9ec592a6
Commit
9ec592a6
authored
Jun 05, 2023
by
guangzlu
Browse files
added new dropout based on global position for fwd pass
parent
a3def557
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
254 additions
and
31 deletions
+254
-31
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v5.cpp
...ale_softmax_gemm/batched_multihead_attention_train_v5.cpp
+4
-4
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
+146
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
...vice_batched_multihead_attention_forward_xdl_cshuffle.hpp
+12
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
...wise_batched_multihead_attention_forward_xdl_cshuffle.hpp
+81
-25
include/ck/utility/philox_rand.hpp
include/ck/utility/philox_rand.hpp
+11
-0
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_train_v5.cpp
View file @
9ec592a6
...
@@ -32,7 +32,7 @@ Kernel outputs:
...
@@ -32,7 +32,7 @@ Kernel outputs:
#define PRINT_HOST 0
#define PRINT_HOST 0
#define USING_MASK 0
#define USING_MASK 0
#define DIM
128
// DIM should be a multiple of 8.
#define DIM
64
// DIM should be a multiple of 8.
#include <iostream>
#include <iostream>
#include <numeric>
#include <numeric>
...
@@ -715,13 +715,13 @@ int run(int argc, char* argv[])
...
@@ -715,13 +715,13 @@ int run(int argc, char* argv[])
ck
::
index_t
M
=
500
;
// 512
ck
::
index_t
M
=
500
;
// 512
ck
::
index_t
K
=
DIM
;
ck
::
index_t
K
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
O
=
DIM
;
ck
::
index_t
G0
=
4
;
// 54
ck
::
index_t
G0
=
1
;
// 54
ck
::
index_t
G1
=
6
;
// 16
ck
::
index_t
G1
=
2
;
// 16
bool
input_permute
=
false
;
bool
input_permute
=
false
;
bool
output_permute
=
false
;
bool
output_permute
=
false
;
float
p_drop
=
0.
0
;
float
p_drop
=
0.
1
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
seed
=
1
;
const
unsigned
long
long
offset
=
0
;
const
unsigned
long
long
offset
=
0
;
...
...
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp
View file @
9ec592a6
...
@@ -122,6 +122,152 @@ struct BlockwiseDropout
...
@@ -122,6 +122,152 @@ struct BlockwiseDropout
});
});
}
}
template
<
typename
CThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout_v1r1
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
tmp_index
=
tmp_index
+
1
;
});
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout_v1r2
(
CThreadBuffer
&
in_thread_buf
,
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
);
}
ushort
tmp_id
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
tmp_id
[
i
*
4
+
j
]
=
element_global_1d_id
+
i
*
8
;
}
}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
tmp
[
tmp_index
]
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
z_thread_buf
(
offset
)
=
tmp_id
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
});
});
}
template
<
typename
CThreadBuffer
,
typename
ZThreadBuffer
,
bool
using_sign_bit
=
false
>
__host__
__device__
void
ApplyDropout_v2
(
CThreadBuffer
&
in_thread_buf
,
ZThreadBuffer
&
z_thread_buf
)
{
auto
execute_dropout
=
[
&
](
bool
keep
,
DataType
val
)
{
if
constexpr
(
using_sign_bit
)
return
keep
?
val
:
-
val
;
else
return
keep
?
val
*
p_dropout_rescale
:
float
(
0
);
};
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
execute_dropout
(
z_thread_buf
(
offset
)
<=
p_dropout_16bits
,
in_thread_buf
(
offset
));
});
});
}
// get raw z matrix with random number for shuffle
template
<
typename
ZThreadBuffer
>
__host__
__device__
void
GenerateZMatrix
(
ck
::
philox
&
ph
,
index_t
element_global_1d_id
,
ZThreadBuffer
&
z_thread_buf
,
index_t
MRaw
)
{
// if(get_thread_global_1d_id() == 0){
// printf("MRepeat & KRepeat is %d , %d . \n", MRepeat, KRepeat);
// }
constexpr
int
tmp_size
=
MRepeat
*
KRepeat
;
int
philox_calls
=
tmp_size
/
4
;
ushort
tmp
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
ph
.
get_random_4x16
((
tmp
+
i
*
4
),
element_global_1d_id
+
i
*
8
*
MRaw
);
}
ushort
tmp_id
[
tmp_size
];
for
(
int
i
=
0
;
i
<
philox_calls
;
i
++
)
{
for
(
int
j
=
0
;
j
<
4
;
j
++
)
{
tmp_id
[
i
*
4
+
j
]
=
element_global_1d_id
+
i
*
8
*
MRaw
;
}
}
block_sync_lds
();
int
tmp_index
=
0
;
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
z_thread_buf
(
offset
)
=
tmp_id
[
tmp_index
];
tmp_index
=
tmp_index
+
1
;
});
});
}
ushort
p_dropout_16bits
;
ushort
p_dropout_16bits
;
DataType
p_dropout_rescale
;
DataType
p_dropout_rescale
;
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
9ec592a6
...
@@ -79,7 +79,9 @@ __global__ void
...
@@ -79,7 +79,9 @@ __global__ void
const
ushort
p_dropout_in_16bits
,
const
ushort
p_dropout_in_16bits
,
const
GemmAccDataType
p_dropout_rescale
,
const
GemmAccDataType
p_dropout_rescale
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
,
const
index_t
MRaw
,
const
index_t
NRaw
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
@@ -131,6 +133,9 @@ __global__ void
...
@@ -131,6 +133,9 @@ __global__ void
p_dropout_in_16bits
,
p_dropout_in_16bits
,
p_dropout_rescale
,
p_dropout_rescale
,
ph
,
ph
,
g_idx
,
MRaw
,
NRaw
,
i
);
i
);
}
}
}
}
...
@@ -160,6 +165,9 @@ __global__ void
...
@@ -160,6 +165,9 @@ __global__ void
p_dropout_in_16bits
,
p_dropout_in_16bits
,
p_dropout_rescale
,
p_dropout_rescale
,
ph
,
ph
,
g_idx
,
MRaw
,
NRaw
,
0
);
0
);
}
}
#else
#else
...
@@ -803,7 +811,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -803,7 +811,9 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
arg
.
p_dropout_in_16bits_
,
arg
.
p_dropout_in_16bits_
,
arg
.
p_dropout_rescale_
,
arg
.
p_dropout_rescale_
,
arg
.
seed_
,
arg
.
seed_
,
arg
.
offset_
);
arg
.
offset_
,
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
0
],
arg
.
raw_lengths_mz_nz_kz_gemm1nz_
[
1
]);
};
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp
View file @
9ec592a6
...
@@ -447,6 +447,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -447,6 +447,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const
ushort
p_dropout_in_16bits
,
const
ushort
p_dropout_in_16bits
,
FloatGemmAcc
p_dropout_rescale
,
FloatGemmAcc
p_dropout_rescale
,
ck
::
philox
&
ph
,
ck
::
philox
&
ph
,
const
index_t
g_idx
,
const
index_t
MRaw
,
const
index_t
NRaw
,
const
index_t
block_idx_m
)
const
index_t
block_idx_m
)
{
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -857,7 +860,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -857,7 +860,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
m0
,
// MRepeat
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
m1
,
// MWaveId
n1
,
// NWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
m2
,
// MPerXdl
...
@@ -888,7 +891,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -888,7 +891,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
Sequence
<
I1
,
// MBlockId
Sequence
<
I1
,
// MBlockId
I1
,
// NBlockID
I1
,
// NBlockID
m0
,
// MRepeat
m0
,
// MRepeat
I1
,
// NRepeat
n0
,
// NRepeat
m1
,
// MWaveId
m1
,
// MWaveId
n1
,
// NWaveId
n1
,
// NWaveId
m2
,
// MPerXdl
m2
,
// MPerXdl
...
@@ -917,7 +920,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -917,7 +920,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
{
block_sync_lds
();
block_sync_lds
();
}
}
do
do
{
{
auto
n_block_data_idx_on_grid
=
auto
n_block_data_idx_on_grid
=
...
@@ -1012,31 +1015,84 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1012,31 +1015,84 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
if
constexpr
(
IsDropout
)
// dropout
if
constexpr
(
IsDropout
)
// dropout
{
{
// 8d thread_desc in thread scope
constexpr
auto
c_thread_lengths
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
// 8d block_desc in block scope
constexpr
auto
c_block_lengths
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
().
GetLengths
();
constexpr
auto
M0
=
c_block_lengths
[
I0
];
constexpr
auto
N0
=
c_block_lengths
[
I1
];
constexpr
auto
M1
=
c_block_lengths
[
I2
];
constexpr
auto
N1
=
c_block_lengths
[
I3
];
constexpr
auto
M2
=
c_block_lengths
[
I4
];
constexpr
auto
N2
=
c_block_lengths
[
I5
];
constexpr
auto
N3
=
c_block_lengths
[
I6
];
constexpr
auto
N4
=
c_block_lengths
[
I7
];
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using
Acc0TileIterator
=
SpaceFillingCurve
<
decltype
(
c_thread_lengths
),
typename
arithmetic_sequence_gen
<
0
,
c_thread_lengths
.
Size
(),
1
>::
type
,
typename
uniform_sequence_gen
<
c_thread_lengths
.
Size
(),
1
>::
type
,
false
>
;
// SnakeCurved
constexpr
auto
block_idx_to_m_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M0
,
M1
,
M2
)),
make_unmerge_transform
(
make_tuple
(
N0
,
N1
,
N2
,
N3
,
N4
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
,
6
,
7
>
{}));
auto
acc0_thread_idx
=
Acc0TileIterator
::
GetIndex
(
I0
)
+
acc0_thread_origin
;
auto
m_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I0
];
auto
n_local
=
block_idx_to_m_n_adaptor
.
CalculateBottomIndex
(
acc0_thread_idx
)[
I1
];
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id
=
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// save z to global
// save z to global
if
(
p_z_grid
)
if
(
p_z_grid
)
{
{
static_for
<
0
,
n0
,
1
>
{}([
&
](
auto
i
)
{
blockwise_dropout
.
template
ApplyDropout_v1r2
<
decltype
(
acc_thread_buf
),
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_tenor_buffer
),
false
>(
false
,
acc_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
);
decltype
(
n0
),
decltype
(
i
)>(
z_thread_copy_vgpr_to_global
.
Run
(
acc_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_thread_copy_vgpr_to_global
.
Run
(
z_tenor_buffer
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_buf
);
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
));
});
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
0
,
0
,
-
(
n0
.
value
),
0
,
0
,
0
,
0
,
0
,
0
));
z_grid_buf
);
// static_for<0, n0, 1>{}([&](auto i) {
// blockwise_dropout.template ApplyDropout<decltype(acc_thread_buf),
// decltype(z_tenor_buffer),
// false,
// decltype(n0),
// decltype(i)>(
// acc_thread_buf, ph, global_elem_id + id_step * i.value,
// z_tenor_buffer);
//
// z_thread_copy_vgpr_to_global.Run(
// z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
// z_tenor_buffer,
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// z_grid_buf);
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, 1, 0, 0, 0, 0, 0, 0));
//});
// z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
// z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
// make_multi_index(0, 0, 0, -(n0.value), 0, 0, 0, 0, 0, 0));
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
make_multi_index
(
0
,
1
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
));
...
@@ -1045,8 +1101,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1045,8 +1101,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
{
{
// ignore = z_grid_buf;
// ignore = z_grid_buf;
// P_dropped
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
false
>(
blockwise_dropout
.
template
ApplyDropout
_v1r1
<
decltype
(
acc_thread_buf
),
false
>(
acc_thread_buf
,
ph
);
acc_thread_buf
,
ph
,
global_elem_id
);
}
}
}
}
...
...
include/ck/utility/philox_rand.hpp
View file @
9ec592a6
...
@@ -84,6 +84,17 @@ class philox
...
@@ -84,6 +84,17 @@ class philox
out_tmp
[
3
]
=
tmp_ph
.
w
;
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
}
__device__
void
get_random_4x16
(
ushort
*
out
,
const
unsigned
long
long
subsequence
)
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
out
[
0
]
=
static_cast
<
ushort
>
(
tmp_ph
.
x
);
out
[
1
]
=
static_cast
<
ushort
>
(
tmp_ph
.
y
);
out
[
2
]
=
static_cast
<
ushort
>
(
tmp_ph
.
z
);
out
[
3
]
=
static_cast
<
ushort
>
(
tmp_ph
.
w
);
}
private:
private:
struct
ull2
struct
ull2
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment