Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c0252636
Commit
c0252636
authored
Jul 22, 2022
by
ltqin
Browse files
regular code
parent
787626fb
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
39 additions
and
36 deletions
+39
-36
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
...de/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
+39
-36
No files found.
include/ck/tensor_operation/gpu/block/blockwise_softmax_v1.hpp
View file @
c0252636
...
@@ -76,17 +76,17 @@ struct BlockwiseSoftmax_V1
...
@@ -76,17 +76,17 @@ struct BlockwiseSoftmax_V1
{
{
// printf("c_thread_desc: {%d, %d, %d}", c_thread_desc.GetLength(I0).value,
// printf("c_thread_desc: {%d, %d, %d}", c_thread_desc.GetLength(I0).value,
// c_thread_desc.GetLength(I1).value, c_thread_desc.GetLength(I2).value);
// c_thread_desc.GetLength(I1).value, c_thread_desc.GetLength(I2).value);
auto
p_reduce_work_buffer
=
static_cast
<
AccDataType
*>
(
p_shared
);
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
AccDataType
*>
(
p_shared
),
BlockSize
);
//
// find max value
//
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
max_value_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
max_value_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
max_value_buf
(
I
)
=
reduce
::
Max
::
template
GetIdentityValue
<
AccDataType
>();
max_value_buf
(
I
)
=
reduce
::
Max
::
template
GetIdentityValue
<
AccDataType
>();
});
});
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
accu_value_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
AccDataType
>();
});
// max value for one thread
// max value for one thread
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
...
@@ -99,47 +99,50 @@ struct BlockwiseSoftmax_V1
...
@@ -99,47 +99,50 @@ struct BlockwiseSoftmax_V1
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
// printf("thread id: %d, Max: %f\t\t",thread_local_id,max_value_buf[I0]);
// ignore = p_reduce_work_buffer;}
// ignore = p_reduce_work_buffer;}
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
BlockwiseMaxReduce
::
Reduce
(
reduce_work_buf
,
max_value_buf
(
I0
));
BlockwiseMaxReduce
::
Reduce
(
reduce_work_buf
,
max_value_buf
(
I0
));
block_sync_lds
();
block_sync_lds
();
// {const index_t thread_local_id = get_thread_local_1d_id();
// {const index_t thread_local_id = get_thread_local_1d_id();
// printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]);}
// printf("thread id: %d, Max: %f\t\t", thread_local_id, max_value_buf[I0]);}
//
// softmax
// softmax
{
//
// calculate exp for elements
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
accu_value_buf
;
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
accu_value_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
AccDataType
>();
});
static_for
<
0
,
RegSizePerXdlops
,
1
>
{}([
&
](
auto
iK
)
{
// calculate exp for elements
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
math
::
exp
(
xdlops_out
.
template
AsType
<
float
>()[
iK
]
-
max_value_buf
(
I0
));
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
});
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
});
// sum data
static_for
<
0
,
RegSizePerXdlops
,
1
>
{}([
&
](
auto
iK
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
math
::
exp
(
xdlops_out
.
template
AsType
<
float
>()[
iK
]
-
max_value_buf
(
I0
));
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
ThreadwiseSumReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
accu_value_buf
);
block_sync_lds
();
});
});
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I0
));
});
// sum data
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
ThreadwiseSumReduce
::
Reduce
(
xdlops_out
.
template
AsType
<
float
>(),
accu_value_buf
);
block_sync_lds
();
block_sync_lds
();
});
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
accu_value_buf
(
I0
));
block_sync_lds
();
// change elements
// change elements
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n
)
{
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
constexpr
index_t
c_offset
=
c_thread_desc
.
CalculateOffset
(
make_tuple
(
0
,
n
,
0
));
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
auto
&
xdlops_out
=
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{});
static_for
<
0
,
c_thread_desc
.
GetLength
(
I2
),
1
>
{}([
&
](
auto
iK
)
{
static_for
<
0
,
c_thread_desc
.
GetLength
(
I2
),
1
>
{}([
&
](
auto
iK
)
{
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
xdlops_out
.
template
AsType
<
float
>()(
iK
)
=
xdlops_out
.
template
AsType
<
float
>()[
iK
]
/
accu_value_buf
(
I0
);
xdlops_out
.
template
AsType
<
float
>()[
iK
]
/
accu_value_buf
(
I0
);
});
});
});
}
}
);
}
}
};
// namespace ck
};
// namespace ck
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment