Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
39274378
"docs/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "c38ecd2eeac6d8497160343ba8080d3b9e780bd3"
Commit
39274378
authored
Aug 09, 2023
by
raman.jana
Browse files
assembly functions for softmax primitives
parent
38f48480
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
890 additions
and
0 deletions
+890
-0
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
+139
-0
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
...sor_operation/gpu/block/reduction_functions_blockwise.hpp
+38
-0
include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp
...r_operation/gpu/thread/reduction_functions_threadwise.hpp
+40
-0
include/ck/utility/reduction_common.hpp
include/ck/utility/reduction_common.hpp
+34
-0
include/ck/utility/reduction_functions_accumulate.hpp
include/ck/utility/reduction_functions_accumulate.hpp
+28
-0
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+514
-0
include/ck/utility/synchronization.hpp
include/ck/utility/synchronization.hpp
+97
-0
No files found.
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
View file @
39274378
...
...
@@ -130,4 +130,143 @@ struct BlockwiseSoftmax
BufferType
sum_value_buf
;
};
// Blockwise softmax helper (variant 1) built on the hand-written assembly
// reduction ops (reduce::Max3 / reduce::fast_Add / reduce::fast_Max) and the
// wave-level PartitionedBlockwiseReduction_v2::WaveReduce.
//
// Template parameters:
//   BlockSize             - number of threads in the block
//   AccDataType           - accumulator element type
//   ThreadMap_M_K         - thread_id to m_k
//   ThreadClusterDesc_M_K - descriptor of the thread cluster layout
//   ThreadSliceDesc_M_K   - descriptor of the per-thread (M, K) data slice
//   IgnoreNaN             - when true, NaN inputs are skipped (their exp()
//                           contribution becomes 0) instead of propagated
template <index_t BlockSize,
          typename AccDataType,
          typename ThreadMap_M_K, // thread_id to m_k
          typename ThreadClusterDesc_M_K,
          typename ThreadSliceDesc_M_K,
          bool IgnoreNaN = false>
struct BlockwiseSoftmax_v1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
    static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);

    // 1-D descriptor covering only the M dimension of the thread slice.
    using ThreadSliceDesc_M = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));

    // Per-thread max reduction; consumes K elements in pairs via reduce::Max3.
    using ThreadwiseMaxReduce = typename conditional<
        IgnoreNaN,
        ThreadwiseReductionDouble<AccDataType,
                                  ThreadSliceDesc_M_K,
                                  ThreadSliceDesc_M,
                                  reduce::Max3,
                                  false,
                                  detail::AccumulateWithNanIgnore<reduce::Max3, AccDataType>>,
        ThreadwiseReductionDouble<AccDataType,
                                  ThreadSliceDesc_M_K,
                                  ThreadSliceDesc_M,
                                  reduce::Max3,
                                  false>>::type;

    // Per-thread sum reduction using the FMA-based adder.
    using ThreadwiseSumReduce = typename conditional<
        IgnoreNaN,
        ThreadwiseReduction<AccDataType,
                            ThreadSliceDesc_M_K,
                            ThreadSliceDesc_M,
                            reduce::fast_Add,
                            false,
                            detail::AccumulateWithNanIgnore<reduce::fast_Add, AccDataType>>,
        ThreadwiseReduction<AccDataType,
                            ThreadSliceDesc_M_K,
                            ThreadSliceDesc_M,
                            reduce::fast_Add,
                            false>>::type;

    using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());

    using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2<AccDataType,
                                                                BlockSize,
                                                                ThreadClusterLengths_M_K,
                                                                ThreadMap_M_K,
                                                                reduce::fast_Max,
                                                                false>;

    using BlockwiseSumReduce = PartitionedBlockwiseReduction_v2<AccDataType,
                                                                BlockSize,
                                                                ThreadClusterLengths_M_K,
                                                                ThreadMap_M_K,
                                                                reduce::Add,
                                                                false>;

    using BufferType = StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MRepeat, true>;

    // Full softmax pass over in_thread_buf: row max -> exp(x - max) -> row sum.
    // On return, in_thread_buf holds the un-normalized exp values while
    // max_value_buf / sum_value_buf hold the per-row max and sum.
    template <typename CThreadBuffer, typename WorkspaceBuffer>
    __host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf)
    {
        // find max value
        static_for<0, MRepeat, 1>{}([&](auto I) {
            max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
        });
        ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
        static_for<0, MRepeat, 1>{}(
            [&](auto I) { BlockwiseMaxReduce::WaveReduce(reduce_work_buf, max_value_buf(I)); });

        // calculate exp for elements, P=exp(s-max)
        // `if constexpr`: IgnoreNaN is a compile-time template constant, so
        // only the selected branch is instantiated (behavior unchanged).
        if constexpr(IgnoreNaN)
        {
            static_for<0, MRepeat, 1>{}([&](auto iM) {
                static_for<0, KRepeat, 1>{}([&](auto iK) {
                    auto offset =
                        Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                    // NaN elements contribute 0 instead of poisoning the row.
                    in_thread_buf(offset) =
                        ck::math::isnan(in_thread_buf[offset])
                            ? 0
                            : math::exp(in_thread_buf[offset] - max_value_buf(iM));
                });
            });
        }
        else
        {
            static_for<0, MRepeat, 1>{}([&](auto iM) {
                static_for<0, KRepeat, 1>{}([&](auto iK) {
                    auto offset =
                        Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                    in_thread_buf(offset) =
                        math::exp(in_thread_buf[offset] - max_value_buf(iM));
                });
            });
        }

        // sum data
        static_for<0, MRepeat, 1>{}([&](auto I) {
            sum_value_buf(I) = reduce::fast_Add::template GetIdentityValue<AccDataType>();
        });
        ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf);
        static_for<0, MRepeat, 1>{}(
            [&](auto I) { BlockwiseSumReduce::WaveReduce(reduce_work_buf, sum_value_buf(I)); });
    }

    // calculate exp for elements using pre-calculated stats LSE (log-sum-exp)
    // Pi = exp(Si) / sum(exp(S0) + exp(S1) + ...)
    //    = exp(Si) / exp(log(sum(exp() + ...)))
    //    = exp(Si - log(sum(exp() + ...)))
    template <typename CThreadBuffer, typename LSEBuffer>
    __host__ __device__ void RunWithPreCalcStats(CThreadBuffer& in_thread_buf,
                                                 const LSEBuffer& lse_thread_buf)
    {
        if constexpr(IgnoreNaN)
        {
            static_for<0, MRepeat, 1>{}([&](auto iM) {
                static_for<0, KRepeat, 1>{}([&](auto iK) {
                    auto offset =
                        Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                    in_thread_buf(offset) =
                        ck::math::isnan(in_thread_buf[offset])
                            ? 0
                            : math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
                });
            });
        }
        else
        {
            static_for<0, MRepeat, 1>{}([&](auto iM) {
                static_for<0, KRepeat, 1>{}([&](auto iK) {
                    auto offset =
                        Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                    in_thread_buf(offset) =
                        math::exp(in_thread_buf[offset] - lse_thread_buf[iM]);
                });
            });
        }
    }

    BufferType max_value_buf; // per-row running maximum
    BufferType sum_value_buf; // per-row running sum of exp()
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
View file @
39274378
...
...
@@ -152,6 +152,44 @@ struct PartitionedBlockwiseReduction_v2
in_out_value
=
work_buffer
[
offset
];
};
    // Tree-reduce `in_out_value` across the K sub-dimension of the thread
    // cluster through the LDS `work_buffer`; on return each thread holds the
    // reduced value for its M row.
    //
    // NOTE(review): synchronization uses lds_waitcnt(0) (s_waitcnt
    // lgkmcnt(0)) rather than a block-wide barrier; that only orders the
    // issuing wavefront's own LDS traffic. Confirm this variant is used only
    // when the whole K cluster lives inside a single wavefront.
    template <typename BufferType>
    __device__ static void WaveReduce(BufferType& work_buffer, AccDataType& in_out_value)
    {
        static_assert(is_same<typename BufferType::type, AccDataType>{},
                      "Buffer data type should be consistent as AccDataType!");

        // number of halving steps = log2 of the K cluster length
        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

        // publish this thread's value into its (m, k) LDS slot
        work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;

        lds_waitcnt(0);

        // binary tree reduction: each step folds the upper half of the K
        // range onto the lower half
        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

            if(thread_k_cluster_id < indOffset)
            {
                index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
                index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                     make_tuple(0, indOffset));

                AccDataType opData1 = work_buffer[offset1];
                AccDataType opData2 = work_buffer[offset2];
                Accumulation::Calculate(opData1, opData2);
                work_buffer(offset1) = opData1;
            }

            lds_waitcnt(0);
        });

        // all threads read back the fully reduced value at (m, 0)
        index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

        in_out_value = work_buffer[offset];
    };
};
// clang-format off
...
...
include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp
View file @
39274378
...
...
@@ -47,6 +47,46 @@ struct ThreadwiseReduction
};
};
// Assume
// 1) SrcDesc is known at compile-time
// 2) DstDesc is known at compile-time
// 3) SrcBuffer is static buffer
// 4) DstBuffer is static buffer
// Threadwise reduction that consumes the K dimension two elements at a time,
// handing element pairs to a 3-operand reduce op (e.g. reduce::Max3).
//
// Assume
// 1) SrcDesc is known at compile-time
// 2) DstDesc is known at compile-time
// 3) SrcBuffer is static buffer
// 4) DstBuffer is static buffer
template <typename AccDataType,
          typename SrcThreadDesc_M_K,
          typename DstThreadDesc_M,
          typename OpReduce,
          bool PropagateNan,
          typename Accumulation =
              detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct ThreadwiseReductionDouble
{
    static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{};
    static constexpr auto dst_thread_desc_m   = DstThreadDesc_M{};

    static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
    static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
    static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});

    static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
    // The inner loop reads elements (iK, iK + 1) in pairs; an odd K length
    // would index one element past the slice, so reject it at compile time.
    static_assert(src_length_k % 2 == 0,
                  "K length must be even for the pairwise (double) reduction!");

    using Op = OpReduce;

    // Reduce src_buf (M x K) into dst_buf (M), two K elements per step.
    template <typename SrcBufferType, typename DstBufferType>
    __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf)
    {
        static_for<0, src_length_m, 1>{}([&](auto iM) {
            constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM));

            static_for<0, src_length_k, 2>{}([&](auto iK) {
                constexpr auto offset =
                    src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
                constexpr auto offset1 =
                    src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK + 1));

                Accumulation::Calculate(dst_buf(Number<out_offset>{}),
                                        src_buf[Number<offset>{}],
                                        src_buf[Number<offset1>{}]);
            });
        });
    };
};
// Assume
// 1) SrcDesc is known at compile-time
// 2) DstDesc is known at compile-time
...
...
include/ck/utility/reduction_common.hpp
View file @
39274378
...
...
@@ -37,4 +37,38 @@ constexpr __device__ index_t get_shift<1>()
return
(
0
);
}
// DPP-based sum of `src` across the 64 lanes of a wavefront; the lane-63
// total is read out with v_readlane and moved back into the vector register.
//
// NOTE(review): the asm writes operands %1 ("v"(src)) and %2 ("s"(sumVal))
// even though both are declared as read-only inputs, and the "0"(val) tie
// reads `val` before it is initialized -- the constraint list looks
// inconsistent with the instruction sequence; verify against the compiler's
// extended-asm rules.
// NOTE(review): v_add_f32 is hard-coded, so T is presumably float (or a
// 32-bit float-compatible type) -- TODO confirm at the call sites.
template <typename T>
__host__ __device__ void waveReduceSum(T& src)
{
    T val;
    index_t sumVal = 0;
    // = __builtin_amdgcn_readlane(src,63);
    asm volatile("\n\
        v_add_f32 %0, %1, %1 row_shr:1 bound_ctrl:0\n\
        v_add_f32 %0, %1, %0 row_shr:2 bound_ctrl:0\n\
        v_add_f32 %0, %1, %0 row_shr:3 bound_ctrl:0\n\
        v_nop\n\
        v_nop\n\
        v_add_f32 %0, %0, %0 row_shr:4 bound_ctrl:0\n\
        v_nop\n\
        v_nop\n\
        v_add_f32 %0, %0, %0 row_shr:8 bound_ctrl:0\n\
        v_nop\n\
        v_nop\n\
        v_add_f32 %1, %0, %0 row_bcast:15 row_mask:0xa\n\
        v_nop\n\
        v_nop\n\
        v_add_f32 %1, %1, %1 row_bcast:31 row_mask:0xc\n\
        v_nop\n\
        v_nop\n\
        v_readlane_b32 %2, %1, 63\n\
        v_nop\n\
        v_nop\n\
        v_mov_b32 %1, %2\n\
        "
                 : "=v"(val)
                 : "v"(src), "s"(sumVal), "0"(val));
}
}
// namespace ck
include/ck/utility/reduction_functions_accumulate.hpp
View file @
39274378
...
...
@@ -22,6 +22,13 @@ struct AccumulateWithNanIgnore
ReduceOperation
{}(
accuVal
,
currVal
);
}
};
    // Pairwise accumulate: fold (currVal, currVal1) into accuVal through the
    // 3-operand ReduceOperation, skipping the update when NaN is present.
    //
    // NOTE(review): when exactly one of the pair is NaN, the other (valid)
    // value is dropped as well -- confirm that is the intended "ignore"
    // semantics rather than accumulating the non-NaN element.
    __device__ static inline void
    Calculate(AccDataType& accuVal, AccDataType currVal, AccDataType currVal1)
    {
        if(!ck::math::isnan(currVal) && !ck::math::isnan(currVal1))
        {
            ReduceOperation{}(accuVal, currVal, currVal1);
        }
    };
};
template
<
bool
PropagateNan
,
typename
ReduceOperation
,
typename
AccDataType
>
...
...
@@ -40,6 +47,10 @@ struct AccumulateWithNanCheck<false, ReduceOperation, AccDataType>
{
ReduceOperation
{}(
accuVal
,
currVal
);
};
__host__
__device__
static
inline
void
Calculate
(
AccDataType
&
accuVal
,
AccDataType
currVal
,
AccDataType
currVal1
)
{
ReduceOperation
{}(
accuVal
,
currVal
,
currVal1
);
};
};
// Check for NaN; guarantees NaNs be propagated to result
...
...
@@ -59,6 +70,23 @@ struct AccumulateWithNanCheck<true, ReduceOperation, AccDataType>
ReduceOperation
{}(
accuVal
,
currVal
);
};
};
__host__
__device__
static
inline
void
Calculate
(
AccDataType
&
accuVal
,
AccDataType
currVal
,
AccDataType
currVal1
)
{
using
ck
::
math
::
isnan
;
if
(
isnan
(
currVal
))
{
accuVal
=
currVal
;
}
else
if
(
isnan
(
currVal1
))
{
accuVal
=
currVal1
;
}
else
{
ReduceOperation
{}(
accuVal
,
currVal
,
currVal1
);
};
};
};
template
<
bool
PropagateNan
,
typename
ReduceOperation
,
typename
AccDataType
,
typename
IndexDataType
>
...
...
include/ck/utility/reduction_operator.hpp
View file @
39274378
...
...
@@ -239,6 +239,519 @@ struct AMax
}
};
// Addition accumulator implemented with an explicit v_fma (a = a * 1 + b) so
// the backend emits exactly one FMA per accumulate on float/half/double.
struct fast_Add
{
    // Identity element of addition.
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };

    // Partial results may be merged in memory by atomic add or a plain store.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

    // a += b
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value,
                      "The data type is not supported by the Add accumulator!");

        T c{1.0f};

        // `if constexpr`: only the branch matching T is instantiated, so the
        // inline asm for the other (differently sized) types is never emitted.
        if constexpr(is_same<T, float>::value)
        {
            asm volatile("v_fma_f32 %0, %0, %1, %2\n"
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
        else if constexpr(is_same<T, half_t>::value)
        {
            asm volatile("v_fma_f16 %0, %0, %1, %2\n"
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
        else if constexpr(is_same<T, double>::value)
        {
            asm volatile("v_fma_f64 %0, %0, %1, %2\n"
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
        else
        {
            a = a + b; // unreachable: static_assert above restricts T
        }
    }
};
// Subtraction accumulator implemented with an explicit v_fma using a constant
// multiplier of -1: a = b * (-1) + a = a - b.
struct fast_Sub
{
    // Identity element: a - 0 = a.
    template <typename T>
    __host__ __device__ static constexpr T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };

    // Same in-memory merge options as the Add accumulator.
    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

    // a -= b
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value,
                      "The data type is not supported by the Add accumulator!");

        T c{-1.0f};

        // Only the branch matching T is instantiated (see fast_Add).
        if constexpr(is_same<T, float>::value)
        {
            asm volatile("v_fma_f32 %0, %2, %1, %0\n"
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
        else if constexpr(is_same<T, half_t>::value)
        {
            asm volatile("v_fma_f16 %0, %2, %1, %0\n"
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
        else if constexpr(is_same<T, double>::value)
        {
            asm volatile("v_fma_f64 %0, %2, %1, %0\n"
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
        else
        {
            a = a - b; // unreachable: static_assert above restricts T
        }
    }
};
// Packed (2-lane) addition accumulator using v_pk_fma: a = a * c + b per lane.
struct Add2
{
    // Identity element of addition.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

    // a += b (lane-wise)
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float2_t>::value || is_same<T, half2_t>::value,
                      "The data type is not supported by the Add accumulator!");

        // NOTE(review): brace-initializing a 2-lane vector from one scalar
        // usually sets only lane 0 to 1.0 (lane 1 becomes 0) -- confirm both
        // lanes of c are 1 for the intended a*1+b semantics.
        T c{1.0f};

        // Only the branch matching T is instantiated.
        if constexpr(is_same<T, float2_t>::value)
        {
            asm volatile("v_pk_fma_f32 %0, %0, %1, %2\n"
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
        else if constexpr(is_same<T, half2_t>::value)
        {
            asm volatile("v_pk_fma_f16 %0, %0, %1, %2\n"
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
    }
};
// Packed (2-lane) subtraction accumulator using v_pk_fma with c = -1:
// a = b * (-1) + a = a - b per lane.
struct Sub2
{
    // Identity element: a - 0 = a.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return type_convert<T>(0.0f);
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::AtomicAdd ||
               operation == InMemoryDataOperationEnum::Set;
    };

    // a -= b (lane-wise)
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float2_t>::value || is_same<T, half2_t>::value,
                      "The data type is not supported by the Add accumulator!");

        // NOTE(review): single-scalar brace init of a 2-lane vector likely
        // leaves lane 1 at 0 -- confirm both lanes of c are -1 (see Add2).
        T c{-1.0f};

        // Only the branch matching T is instantiated.
        if constexpr(is_same<T, float2_t>::value)
        {
            asm volatile("v_pk_fma_f32 %0, %2, %1, %0\n"
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
        else if constexpr(is_same<T, half2_t>::value)
        {
            asm volatile("v_pk_fma_f16 %0, %2, %1, %0\n"
                         : "=v"(a)
                         : "v"(c), "v"(b), "0"(a));
        }
    }
};
// Packed (2-lane) multiply accumulator: a *= b via v_pk_mul.
struct Mul2
{
    // Identity element of multiplication.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return type_convert<T>(1.0f);
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        return operation == InMemoryDataOperationEnum::Set;
    };

    // a *= b (lane-wise)
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float2_t>::value || is_same<T, half2_t>::value,
                      "The data type is not supported by the Mul accumulator!");

        if constexpr(is_same<T, float2_t>::value)
        {
            asm volatile("v_pk_mul_f32 %0, %0, %1\n"
                         : "=v"(a)
                         : "v"(b), "0"(a));
        }
        // Bug fix: the original tested is_same<T, half_t>, which can never be
        // true for the static_assert-allowed half2_t, so the half2 multiply
        // was silently skipped (a left unchanged).
        else if constexpr(is_same<T, half2_t>::value)
        {
            asm volatile("v_pk_mul_f16 %0, %0, %1\n"
                         : "=v"(a)
                         : "v"(b), "0"(a));
        }
    }
};
// Max accumulator using single v_max instructions for float/half; the other
// supported types fall back to a compare-and-assign.
struct fast_Max
{
    // Identity element of max.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return NumericLimits<T>::Lowest();
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

    // a = max(a, b)
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

        // Only the branch matching T is instantiated.
        if constexpr(is_same<T, float>::value)
        {
            asm volatile("v_max_f32 %0, %0, %1\n" : "=v"(a) : "v"(b), "0"(a));
        }
        else if constexpr(is_same<T, half_t>::value)
        {
            asm volatile("v_max_f16 %0, %0, %1\n" : "=v"(a) : "v"(b), "0"(a));
        }
        else
        {
            // double / int32_t / int8_t: plain compare-and-assign
            if(a < b)
                a = b;
        }
    }

    // a = max(a, b); `changed` reports whether the accumulator was updated.
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
};
// Three-operand max accumulator: a = max(a, b, c) in a single v_max3
// instruction (used by the pairwise ThreadwiseReductionDouble).
struct Max3
{
    // Identity element of max.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return NumericLimits<T>::Lowest();
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_max to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

    // a = max(a, b, c)
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b, T c) const
    {
        static_assert(is_same<T, float>::value || is_same<T, half_t>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Max accumulator!");

        // Only the branch matching T is instantiated.
        if constexpr(is_same<T, float>::value)
        {
            asm volatile("v_max3_f32 %0, %0, %1, %2\n"
                         : "=v"(a)
                         : "v"(b), "v"(c), "0"(a));
        }
        else if constexpr(is_same<T, half_t>::value)
        {
            asm volatile("v_max3_f16 %0, %0, %1, %2\n"
                         : "=v"(a)
                         : "v"(b), "v"(c), "0"(a));
        }
        else
        {
            // int32_t (the only remaining type the static_assert allows)
            asm volatile("v_max3_i32 %0, %0, %1, %2\n"
                         : "=v"(a)
                         : "v"(b), "v"(c), "0"(a));
        }
    }

    // a = max(a, b); `changed` reports whether the accumulator was updated.
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Max accumulator!");

        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }
};
// Min accumulator using single v_min instructions where available.
// Bug fixes vs. the original: the non-asm fallback and the `changed` overload
// used the max comparison `if (a < b) a = b` -- they computed the maximum,
// not the minimum.
struct fast_Min
{
    // Identity element of min.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return NumericLimits<T>::Max();
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_min to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

    // a = min(a, b)
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

        // Only the branch matching T is instantiated.
        if constexpr(is_same<T, float>::value)
        {
            asm volatile("v_min_f32 %0, %0, %1\n" : "=v"(a) : "v"(b), "0"(a));
        }
        else if constexpr(is_same<T, half_t>::value)
        {
            asm volatile("v_min_f16 %0, %0, %1\n" : "=v"(a) : "v"(b), "0"(a));
        }
        else if constexpr(is_same<T, double>::value)
        {
            asm volatile("v_min_f64 %0, %0, %1\n" : "=v"(a) : "v"(b), "0"(a));
        }
        else if constexpr(is_same<T, int32_t>::value)
        {
            asm volatile("v_min_i32 %0, %0, %1\n" : "=v"(a) : "v"(b), "0"(a));
        }
        else
        {
            // int8_t fallback; fixed from `a < b` (max) to `b < a` (min)
            if(b < a)
                a = b;
        }
    }

    // a = min(a, b); `changed` reports whether the accumulator was updated.
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

        if(b < a) // fixed: was `a < b`, which selected the maximum
        {
            a       = b;
            changed = true;
        }
    }
};
// Three-operand min accumulator: a = min(a, b, c) in a single v_min3
// instruction. Bug fix vs. the original: the `changed` overload used the max
// comparison `if (a < b) a = b` and so computed the maximum, not the minimum.
struct Min3
{
    // Identity element of min.
    template <typename T>
    __host__ __device__ static T GetIdentityValue()
    {
        return NumericLimits<T>::Max();
    };

    __host__ __device__ static constexpr bool
    IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
    {
        // ToChange: atomic_min to be added
        return operation == InMemoryDataOperationEnum::Set;
    };

    // a = min(a, b, c)
    template <typename T>
    __host__ __device__ inline void operator()(T& a, T b, T c) const
    {
        static_assert(is_same<T, float>::value || is_same<T, half_t>::value ||
                          is_same<T, int32_t>::value,
                      "The data type is not supported by the Min accumulator!");

        // Only the branch matching T is instantiated.
        if constexpr(is_same<T, float>::value)
        {
            asm volatile("v_min3_f32 %0, %0, %1, %2\n"
                         : "=v"(a)
                         : "v"(b), "v"(c), "0"(a));
        }
        else if constexpr(is_same<T, half_t>::value)
        {
            asm volatile("v_min3_f16 %0, %0, %1, %2\n"
                         : "=v"(a)
                         : "v"(b), "v"(c), "0"(a));
        }
        else
        {
            // int32_t (the only remaining type the static_assert allows)
            asm volatile("v_min3_i32 %0, %0, %1, %2\n"
                         : "=v"(a)
                         : "v"(b), "v"(c), "0"(a));
        }
    }

    // a = min(a, b); `changed` reports whether the accumulator was updated.
    template <typename T>
    __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
    {
        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
                          is_same<T, int8_t>::value,
                      "The data type is not supported by the Min accumulator!");

        if(b < a) // fixed: was `a < b`, which selected the maximum
        {
            a       = b;
            changed = true;
        }
    }
};
template
<
typename
T
>
constexpr
T
GetIdentityValueForInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
...
...
@@ -288,5 +801,6 @@ struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Add, D
is_same
<
DataType
,
int32_t
>::
value
;
};
}
// namespace reduce
}
// namespace ck
include/ck/utility/synchronization.hpp
View file @
39274378
...
...
@@ -28,5 +28,102 @@ __device__ void s_nop()
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
// Workgroup-wide execution barrier: s_barrier stalls each wave until every
// wave in the workgroup has reached it.
__device__ void wg_sync()
{
    asm volatile("\
    s_barrier\n\
    " ::);
}
// Raise this wave's issue priority to the maximum (priority 3).
// NOTE(review): the operand is written `s_setprio(3)`; ISA listings spell it
// `s_setprio 3`. The assembler may accept "(3)" as a parenthesized
// expression, but confirm this assembles as intended.
__device__ void raise_priority()
{
    asm volatile("\
    s_setprio(3)\n\
    " ::);
}
// Restore this wave's issue priority to the default (priority 0).
// NOTE(review): same parenthesized-operand syntax concern as raise_priority.
__device__ void lower_priority()
{
    asm volatile("\
    s_setprio(0)\n\
    " ::);
}
// Wait until at most `cnt` vector-memory operations are still outstanding.
// s_waitcnt takes an immediate, so only the counts 0/2/4/8/12 map to their
// exact vmcnt value; any other count waits on vmcnt(16).
__device__ void vm_waitcnt(const uint32_t cnt)
{
    switch(cnt)
    {
    case 0:
        asm volatile("\
    s_waitcnt vmcnt(0)\n\
    " ::);
        break;
    case 2:
        asm volatile("\
    s_waitcnt vmcnt(2)\n\
    " ::);
        break;
    case 4:
        asm volatile("\
    s_waitcnt vmcnt(4)\n\
    " ::);
        break;
    case 8:
        asm volatile("\
    s_waitcnt vmcnt(8)\n\
    " ::);
        break;
    case 12:
        asm volatile("\
    s_waitcnt vmcnt(12)\n\
    " ::);
        break;
    default:
        asm volatile("\
    s_waitcnt vmcnt(16)\n\
    " ::);
        break;
    }
}
// Wait until at most `cnt` LDS/GDS (lgkm) operations are still outstanding.
// Only the counts 0/4/8/12 map to their exact lgkmcnt value; any other count
// waits on lgkmcnt(16).
__device__ void lds_waitcnt(const uint32_t cnt)
{
    switch(cnt)
    {
    case 0:
        asm volatile("\
    s_waitcnt lgkmcnt(0)\n\
    " ::);
        break;
    case 4:
        asm volatile("\
    s_waitcnt lgkmcnt(4)\n\
    " ::);
        break;
    case 8:
        asm volatile("\
    s_waitcnt lgkmcnt(8)\n\
    " ::);
        break;
    case 12:
        asm volatile("\
    s_waitcnt lgkmcnt(12)\n\
    " ::);
        break;
    default:
        asm volatile("\
    s_waitcnt lgkmcnt(16)\n\
    " ::);
        break;
    }
}
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment