Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
787626fb
Commit
787626fb
authored
Jul 22, 2022
by
ltqin
Browse files
add n repeate function
parent
da047ec1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
96 additions
and
69 deletions
+96
-69
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
...de/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
+95
-66
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+1
-3
No files found.
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
View file @
787626fb
...
...
@@ -17,17 +17,60 @@ template <index_t BlockSize,
index_t
NPerXDL
,
index_t
RegSizePerXdlops
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
MThreadSliceSize
,
index_t
NThreadSliceSize
>
index_t
NRepeat
>
struct
BlockwiseSoftmax_V1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static_assert
(
MRepeat
==
1
,
"Now MRepeat must equal 1"
);
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
index_t
MThreadSliceSize
=
1
;
static
constexpr
index_t
WaveSize
=
64
;
constexpr
static
auto
c_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
RegSizePerXdlops
>
{}));
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{},
Number
<
RegSizePerXdlops
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{})));
using
ThreadwiseMaxReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Max
,
false
,
// param ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
using
ThreadClusterLengths_M_K
=
Sequence
<
MPerXDL
,
WaveSize
/
MPerXDL
>
;
using
ThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
using
BlockwiseMaxReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Max
,
false
,
// param ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
using
BlockwiseSumReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Add
,
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
using
ThreadwiseSumReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Add
,
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
template
<
typename
CThreadBuffer
>
__host__
__device__
static
void
Run
(
CThreadBuffer
&
c_thread_buf
,
void
*
__restrict__
p_shared
)
{
...
...
@@ -44,74 +87,60 @@ struct BlockwiseSoftmax_V1
accu_value_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
AccDataType
>();
});
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
));
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
using
ThreadReduceSrcDesc_M_K
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{},
Number
<
c_thread_desc
.
GetLength
(
I2
)
>
{})));
using
ThreadReduceDstDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{})));
using
ThreadwiseMaxReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Max
,
false
,
// param ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
ThreadwiseMaxReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
max_value_buf
);
// const index_t thread_local_id = get_thread_local_1d_id();
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
// max value for one thread
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
using
ThreadClusterLengths_M_K
=
Sequence
<
32
,
2
>
;
using
ThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
using
BlockwiseMaxReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Max
,
false
,
// param ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
;
ThreadwiseMaxReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
max_value_buf
);
});
//{const index_t thread_local_id = get_thread_local_1d_id();
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
// ignore = p_reduce_work_buffer;}
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
block_sync_lds
();
BlockwiseMaxReduce
::
Reduce
(
reduce_work_buf
,
max_value_buf
(
I0
));
block_sync_lds
();
//
printf("\n"
);
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
//
{const index_t thread_local_id = get_thread_local_1d_id(
);
// printf("thread id: %d, Max: %f\t\t",
thread_local_id,
max_value_buf[I0]);
}
// softmax
using
BlockwiseSumReduce
=
PartitionedBlockwiseReduction
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadClusterArrangeOrder
,
reduce
::
Add
,
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
using
ThreadwiseSumReduce
=
ThreadwiseReduction
<
AccDataType
,
ThreadReduceSrcDesc_M_K
,
ThreadReduceDstDesc_M
,
reduce
::
Add
,
false
,
// ignored
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
;
static_for
<
0
,
c_thread_desc
.
GetLength
(
I2
),
1
>
{}([
&
](
auto
iK
)
{
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
math
::
exp
(
xdlops_out
.
template
AsType
<
float
>()[
iK
]
-
max_value_buf
(
I0
));
});
ThreadwiseSumReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
accu_value_buf
);
block_sync_lds
();
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I0
));
block_sync_lds
();
static_for
<
0
,
c_thread_desc
.
GetLength
(
I2
),
1
>
{}([
&
](
auto
iK
)
{
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
xdlops_out
.
template
AsType
<
float
>()[
iK
]
/
accu_value_buf
(
I0
);
});
{
// calculate exp for elements
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
static_for
<
0
,
RegSizePerXdlops
,
1
>
{}([
&
](
auto
iK
)
{
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
math
::
exp
(
xdlops_out
.
template
AsType
<
float
>()[
iK
]
-
max_value_buf
(
I0
));
});
});
// sum data
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
ThreadwiseSumReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
accu_value_buf
);
block_sync_lds
();
});
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I0
));
block_sync_lds
();
// change elements
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
static_for
<
0
,
c_thread_desc
.
GetLength
(
I2
),
1
>
{}([
&
](
auto
iK
)
{
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
xdlops_out
.
template
AsType
<
float
>()[
iK
]
/
accu_value_buf
(
I0
);
});
});
}
}
};
};
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
787626fb
...
...
@@ -480,9 +480,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
NPerXDL
,
blockwise_gemm
.
GetRegSizePerXdlops
(),
MXdlPerWave
,
NXdlPerWave
,
1
,
1
>
;
NXdlPerWave
>
;
BlockwiseSoftmax
::
Run
(
c_thread_buf
,
p_reduce_work_buffer
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment