gaoqiong / composable_kernel · Commits

Commit 24f99138, authored Sep 20, 2022 by wangshaojie6

Merge branch 'develop' into att_with_MNKOPadding

Parents: 31d2d52a, 4eba345f
Changes: 42 files in total; this view shows 20 changed files with 1439 additions and 241 deletions (+1439, -241).
Changed files shown:

  client_example/05_layernorm/layernorm2d.cpp (+2, -2)
  example/27_layernorm/layernorm_blockwise.cpp (+23, -20)
  example/39_permute/CMakeLists.txt (+9, -0)
  example/39_permute/common.hpp (+468, -0)
  example/39_permute/permute_1xHxW_fp16.cpp (+20, -0)
  example/39_permute/permute_HxWx4_fp16.cpp (+22, -0)
  example/39_permute/permute_NxHxW_fp16.cpp (+20, -0)
  example/39_permute/run_permute_bundle_example.inc (+78, -0)
  example/39_permute/run_permute_element_example.inc (+65, -0)
  example/42_groupnorm/CMakeLists.txt (+1, -0)
  example/42_groupnorm/groupnorm_sigmoid_fp16.cpp (+172, -0)
  include/ck/tensor_operation/gpu/device/device_base.hpp (+1, -0)
  include/ck/tensor_operation/gpu/device/device_elementwise.hpp (+25, -9)
  include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp (+88, -107)
  include/ck/tensor_operation/gpu/device/device_permute.hpp (+37, -0)
  include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp (+282, -0)
  include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp (+15, -0)
  include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp (+4, -0)
  include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp (+55, -53)
  include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp (+52, -50)
client_example/05_layernorm/layernorm2d.cpp

@@ -81,8 +81,8 @@ int main(int argc, char* argv[])
     auto argument_ptr = op_ptr->MakeArgumentPointer({M, N},      // lengths
                                                     {Stride, 1}, // xStrides
-                                                    {1},         // gammaStrides
-                                                    {1},         // betaStrides
+                                                    {0, 1},      // gammaStrides
+                                                    {0, 1},      // betaStrides
                                                     {Stride, 1}, // yStrides
                                                     {1},         // reduceDims
                                                     1e-4,
example/27_layernorm/layernorm_blockwise.cpp

@@ -29,24 +29,27 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 constexpr int Rank         = 2;
 constexpr int NumReduceDim = 1;
 
-using DeviceInstance = ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
-                                                                         GammaDataType,
-                                                                         BetaDataType,
-                                                                         AccDataType,
-                                                                         YDataType,
-                                                                         PassThrough,
-                                                                         Rank,
-                                                                         NumReduceDim,
-                                                                         256, // BlockSize
-                                                                         8,   // ClusterM
-                                                                         32,  // ClusterK
-                                                                         1,   // SliceM
-                                                                         8,   // SliceK
-                                                                         1,   // SrcVecDim (0=M, 1=K)
-                                                                         8,   // SrcScalarPerVector
-                                                                         8,   // GammaScalarPerVector
-                                                                         8,   // BetaScalarPerVector
-                                                                         8>;  // OutScalarPerVector
+using DeviceInstance =
+    ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
+                                                      GammaDataType,
+                                                      BetaDataType,
+                                                      AccDataType,
+                                                      YDataType,
+                                                      PassThrough,
+                                                      Rank,
+                                                      NumReduceDim,
+                                                      256, // BlockSize
+                                                      8,   // ClusterM
+                                                      32,  // ClusterK
+                                                      1,   // SliceM
+                                                      8,   // SliceK
+                                                      1,   // SrcVecDim (0=M, 1=K)
+                                                      8,   // SrcScalarPerVector
+                                                      1,   // GammaVecDim (0=M, 1=K)
+                                                      8,   // GammaScalarPerVector
+                                                      1,   // BetaVecDim (0=M, 1=K)
+                                                      8,   // BetaScalarPerVector
+                                                      8>;  // OutScalarPerVector
 
 int main()
 {
@@ -88,8 +91,8 @@ int main()
     auto argument_ptr = device_instance.MakeArgumentPointer(
         {M, N},
         std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
-        std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(), gamma.mDesc.GetStrides().end()},
-        std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(), beta.mDesc.GetStrides().end()},
+        {0, 1},
+        {0, 1},
         std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
         {1},
         1e-4,
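The change above passes gamma and beta as rank-2 views whose M-stride is 0, so a single length-N vector is broadcast over every row instead of being described as a 1-D tensor. A minimal sketch of the indexing this implies (helper name and types are illustrative, not from the commit):

    // With strides {0, 1}, element (m, n) of gamma reads gamma_data[0 * m + 1 * n],
    // i.e. gamma_data[n] for every row m.
    #include <cstddef>
    #include <vector>

    float broadcast_at(const std::vector<float>& data, std::size_t m, std::size_t n,
                       std::size_t stride_m, std::size_t stride_n)
    {
        return data[stride_m * m + stride_n * n];
    }

    // broadcast_at(gamma, m, n, /*stride_m=*/0, /*stride_n=*/1) == gamma[n] for any m.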
example/39_permute/CMakeLists.txt (new file, mode 100644)

add_custom_target(example_permute)
add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp)
add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp)
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
add_dependencies(example_permute example_permute_1xHxW_fp16)
add_dependencies(example_permute example_permute_NxHxW_fp16)
add_dependencies(example_permute example_permute_HxWx4_fp16)
example/39_permute/common.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iterator>
#include <numeric>
#include <type_traits>
#include <utility>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/utility/type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
using F16 = ck::half_t;
using F32 = float;
using F64 = double;

struct Problem final
{
    static constexpr std::size_t NumDim = 3;

    using Shape = std::array<std::size_t, NumDim>;
    using Axes  = Shape;

    Problem() = delete;

    explicit Problem(const Shape& default_shape, const Axes& default_axes)
        : shape(default_shape), axes(default_axes)
    {
    }

    Shape shape;
    Axes axes;
};

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

namespace detail {

template <typename Array, std::size_t Difference>
struct enlarge_array_size;

template <typename T, std::size_t Size, std::size_t Difference>
struct enlarge_array_size<std::array<T, Size>, Difference>
{
    using type = std::array<T, Size + Difference>;
};

template <typename Array, std::size_t Difference>
using enlarge_array_size_t = typename enlarge_array_size<Array, Difference>::type;

template <typename Array>
struct get_array_size;

template <typename T, std::size_t Size>
struct get_array_size<std::array<T, Size>> : std::integral_constant<std::size_t, Size>
{
};

template <typename Array>
inline constexpr std::size_t get_array_size_v = get_array_size<Array>::value;

template <typename T, typename = void>
struct is_iterator : std::false_type
{
};

template <typename T>
struct is_iterator<T,
                   std::void_t<decltype(*std::declval<T>()),
                               decltype(++std::declval<std::add_lvalue_reference_t<T>>()),
                               decltype(std::declval<std::add_lvalue_reference_t<T>>()++)>>
    : std::true_type
{
};

template <typename T>
inline constexpr bool is_iterator_v = is_iterator<T>::value;

struct Placeholder final
{
    template <typename T>
    constexpr inline operator T() const noexcept;
};

template <typename Iterator, typename = void>
struct is_output_iterator : std::false_type
{
};

template <typename Iterator>
struct is_output_iterator<
    Iterator,
    std::void_t<decltype(*std::declval<Iterator>() = std::declval<Placeholder>())>>
    : std::bool_constant<is_iterator_v<Iterator>>
{
};

template <typename T>
inline constexpr bool is_output_iterator_v = is_output_iterator<T>::value;

template <typename Iterator, typename = void>
struct is_bidirectional_iterator : std::false_type
{
};

template <typename Iterator>
struct is_bidirectional_iterator<
    Iterator,
    std::void_t<decltype(--std::declval<std::add_lvalue_reference_t<Iterator>>()),
                decltype(std::declval<std::add_lvalue_reference_t<Iterator>>()--)>>
    : std::bool_constant<is_iterator_v<Iterator>>
{
};

template <typename Iterator>
inline constexpr bool is_bidirectional_iterator_v = is_bidirectional_iterator<Iterator>::value;

template <typename Iterator, typename = void>
struct is_random_access_iterator : std::false_type
{
};

template <typename Iterator>
struct is_random_access_iterator<Iterator,
                                 std::void_t<decltype(std::declval<Iterator>() + 1),
                                             decltype(std::declval<Iterator>() - 1),
                                             decltype(std::declval<Iterator>()[1])>>
    : std::bool_constant<is_iterator_v<Iterator>>
{
};

template <typename Iterator>
inline constexpr bool is_random_access_iterator_v = is_random_access_iterator<Iterator>::value;

template <typename T, typename = void>
struct is_range : std::false_type
{
};

template <typename T>
struct is_range<T,
                std::void_t<decltype(begin(std::declval<T>())),
                            decltype(end(std::declval<T>())),
                            decltype(begin(std::declval<T>()) != end(std::declval<T>()))>>
    : std::bool_constant<is_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<T>()))>>>
{
};

template <typename T>
inline constexpr bool is_range_v = is_range<T>::value;

template <typename Range, typename = void>
struct is_sized_range : std::false_type
{
};

template <typename Range>
struct is_sized_range<Range, std::void_t<decltype(size(std::declval<Range>()))>>
    : std::bool_constant<is_range_v<Range>>
{
};

template <typename Range>
inline constexpr bool is_sized_range_v = is_sized_range<Range>::value;

template <typename Range, typename = void>
struct is_bidirectional_range : std::false_type
{
};

template <typename Range>
struct is_bidirectional_range<Range, std::void_t<>>
    : std::bool_constant<
          is_range_v<Range> &&
          is_bidirectional_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
{
};

template <typename Range>
inline constexpr bool is_bidirectional_range_v = is_bidirectional_range<Range>::value;

template <typename Range, typename = void>
struct is_random_access_range : std::false_type
{
};

template <typename Range>
struct is_random_access_range<Range, std::void_t<>>
    : std::bool_constant<
          is_range_v<Range> &&
          is_random_access_iterator_v<ck::remove_cvref_t<decltype(begin(std::declval<Range>()))>>>
{
};

template <typename Range>
inline constexpr bool is_random_access_range_v = is_random_access_range<Range>::value;

template <typename Range>
class to_array_proxy
{
    static_assert(is_range_v<Range>);

    public:
    explicit to_array_proxy(const Range& source) noexcept : source_(source) {}

    template <typename T, std::size_t Size>
    operator std::array<T, Size>() const
    {
        std::array<T, Size> destination;
        std::copy_n(std::begin(source_),
                    std::min<std::size_t>(Size, std::size(source_)),
                    std::begin(destination));
        return destination;
    }

    private:
    const Range& source_;
};

} // namespace detail

template <typename Range>
inline auto to_array(Range& range) noexcept
    -> std::enable_if_t<detail::is_range_v<Range>,
                        detail::to_array_proxy<ck::remove_cvref_t<Range>>>
{
    return detail::to_array_proxy<ck::remove_cvref_t<Range>>{range};
}

namespace ranges {

template <typename InputRange, typename OutputIterator>
inline auto copy(InputRange&& range, OutputIterator iter)
    -> decltype(std::copy(std::begin(std::forward<InputRange>(range)),
                          std::end(std::forward<InputRange>(range)),
                          iter))
{
    return std::copy(std::begin(std::forward<InputRange>(range)),
                     std::end(std::forward<InputRange>(range)),
                     iter);
}

} // namespace ranges

template <typename Axes>
inline auto is_valid_axes(const Axes& axes)
    -> std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
{
    using std::empty;
    if(empty(axes))
    {
        return false;
    }

    using std::begin, std::end;
    std::vector<std::size_t> sorted_axes(begin(axes), end(axes));
    std::sort(begin(sorted_axes), end(sorted_axes));
    const auto last = std::unique(begin(sorted_axes), end(sorted_axes));

    return (last == end(sorted_axes)) && (*begin(sorted_axes) == 0) &&
           (*std::prev(last) == size(axes) - 1);
}

template <typename Shape>
inline auto is_valid_shape(const Shape& shape)
    -> std::enable_if_t<detail::is_range_v<Shape>, bool>
{
    static_assert(std::is_unsigned_v<ck::remove_cvref_t<decltype(*std::begin(shape))>>);

    using std::begin, std::end;
    using std::empty;

    return !empty(shape) &&
           std::all_of(begin(shape), end(shape), [](auto dim) { return 0 < dim; });
}

template <typename Shape, typename Indices>
inline auto is_valid_indices(const Shape& shape, const Indices& indices)
    -> std::enable_if_t<detail::is_sized_range_v<Shape> && detail::is_sized_range_v<Indices>, bool>
{
    static_assert(std::is_unsigned_v<ck::remove_cvref_t<decltype(*std::begin(indices))>>);

    if(!is_valid_shape(shape))
    {
        return false;
    }

    using std::empty;
    if(empty(indices))
    {
        return false;
    }

    using std::size;
    if(size(shape) != size(indices))
    {
        return false;
    }

    using std::begin, std::end;
    auto dim = begin(shape);
    auto idx = begin(indices);
    for(; dim != end(shape) && idx != end(indices); ++dim, ++idx)
    {
        if(*dim <= *idx)
        {
            return false;
        }
    }

    return true;
}

template <std::size_t Size>
std::array<std::size_t, Size> transpose(const std::array<std::size_t, Size>& shape,
                                        const std::array<std::size_t, Size>& axes)
{
    assert(is_valid_shape(shape) && is_valid_axes(axes));

    std::array<std::size_t, Size> transposed;
    auto iter = std::begin(transposed);
    for(const auto axis : axes)
    {
        *iter++ = shape[axis];
    }

    return transposed;
}

auto extend_shape(const Problem::Shape& shape, std::size_t new_dim)
{
    detail::enlarge_array_size_t<Problem::Shape, 1> extended_shape;

    using std::begin, std::end;
    std::copy(begin(shape), end(shape), begin(extended_shape));
    extended_shape.back() = new_dim;

    return extended_shape;
}

auto extend_axes(const Problem::Axes& axes)
{
    detail::enlarge_array_size_t<Problem::Axes, 1> extended_axes;

    using std::begin, std::end;
    std::copy(begin(axes), end(axes), begin(extended_axes));
    extended_axes.back() = detail::get_array_size_v<Problem::Axes>;

    return extended_axes;
}

template <typename Shape, typename Indices>
auto advance_indices(const Shape& shape, Indices& indices)
    -> std::enable_if_t<detail::is_bidirectional_range_v<Shape> &&
                            detail::is_sized_range_v<Shape> &&
                            detail::is_bidirectional_range_v<Indices> &&
                            detail::is_sized_range_v<Indices>,
                        bool>
{
    using std::size;
    if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) &&
         size(shape) == size(indices)))
    {
        return false;
    }

    bool carry = true;

    using std::rbegin, std::rend;
    auto dim = rbegin(shape);
    auto idx = rbegin(indices);
    for(; carry && dim != rend(shape) && idx != rend(indices); ++dim, ++idx)
    {
        *idx  = (*idx + carry);
        carry = ((*idx == *dim) ? (*idx = 0, true) : false);
    }

    return !carry;
}

template <typename Src, typename Axes, typename Functor, typename Dest>
auto host_permute(const Tensor<Src>& src, const Axes& axes, Functor functor, Tensor<Dest>& dest)
    -> std::enable_if_t<detail::is_random_access_range_v<Axes> &&
                            detail::is_sized_range_v<Axes> &&
                            std::is_invocable_v<Functor,
                                                std::add_lvalue_reference_t<Dest>,
                                                std::add_lvalue_reference_t<Src>>,
                        bool>
{
    const auto& shape            = src.mDesc.GetLengths();
    const auto& transposed_shape = dest.mDesc.GetLengths();
    if(!(is_valid_shape(shape) && is_valid_shape(transposed_shape)))
    {
        return false;
    }

    using std::size;
    if(!is_valid_axes(axes))
    {
        return false;
    }

    static_assert(detail::is_sized_range_v<ck::remove_cvref_t<decltype(shape)>> &&
                  detail::is_sized_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
    if(size(shape) != size(transposed_shape))
    {
        return false;
    }

    static_assert(detail::is_random_access_range_v<ck::remove_cvref_t<decltype(shape)>> &&
                  detail::is_random_access_range_v<ck::remove_cvref_t<decltype(transposed_shape)>>);
    {
        for(std::size_t idx = 0; idx < size(shape); ++idx)
        {
            if(transposed_shape[idx] != shape[axes[idx]])
            {
                return false;
            }
        }
    }

    std::vector<std::size_t> indices(size(shape), 0);
    if(!is_valid_indices(shape, indices))
    {
        return false;
    }

    switch(size(shape))
    {
    case 3: {
        do
        {
            Dest output = 0;
            functor(output, src(indices[0], indices[1], indices[2]));
            dest(indices[axes[0]], indices[axes[1]], indices[axes[2]]) = output;
        } while(advance_indices(shape, indices));
    }
    break;
    case 4: {
        do
        {
            Dest output = 0;
            functor(output, src(indices[0], indices[1], indices[2], indices[3]));
            dest(indices[axes[0]], indices[axes[1]], indices[axes[2]], indices[axes[3]]) = output;
        } while(advance_indices(shape, indices));
    }
    break;
    default: return false;
    }

    return true;
}
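For orientation, a minimal host-side sketch of how the helpers above compose: transpose() permutes the shape by the given axes, and advance_indices() walks the index tuple in row-major order like an odometer (the small shape values here are illustrative, not from the commit):

    // Sketch only: assumes the Problem, transpose and advance_indices definitions above.
    Problem::Shape shape{1, 4, 3};
    Problem::Axes  axes{0, 2, 1};

    auto transposed = transpose(shape, axes); // {1, 3, 4}

    std::array<std::size_t, 3> idx{0, 0, 0};
    do
    {
        // Source element idx maps to destination element
        // (idx[axes[0]], idx[axes[1]], idx[axes[2]]), which is what host_permute() does.
    } while(advance_indices(shape, idx));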
example/39_permute/permute_1xHxW_fp16.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

using InDataType  = F16;
using OutDataType = F16;

// clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer|   InBlock|      InBlockTransfer|           InBlockTransfer|       Src|       Dst|             Src|             Dst|
// ######|       |   Type|    Type|   Operation|  Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
         <      3, InDataType, OutDataType, PassThrough, 256, 1, 32, 32, 3, S<1, 32, 8>, S<0, 1, 2>, 2, 1, 2, 1>;
// clang-format on

#include "run_permute_element_example.inc"

int main() { return !run_permute_element_example({1, 32000, 80}, {0, 2, 1}); }
example/39_permute/permute_HxWx4_fp16.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

using DataType   = F16;
using BundleType = F64;

static_assert(sizeof(BundleType) % sizeof(DataType) == 0);

// clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer|   InBlock|      InBlockTransfer|           InBlockTransfer|       Src|       Dst|             Src|             Dst|
// ######|       |   Type|    Type|   Operation|  Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
         <      3, BundleType, BundleType, PassThrough, 256, 1, 32, 32, 5, S<1, 32, 8>, S<0, 1, 2>, 2, 1, 4, 1>;
// clang-format on

#include "run_permute_bundle_example.inc"

int main() { return !run_permute_bundle_example({1, 80, 32000}, {0, 2, 1}); }
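This example drives the permute kernel on "bundles": groups of contiguous fp16 values reinterpreted as one wider element so the trailing length-4 dimension moves as a unit. A small sketch of the size arithmetic behind that trick (illustrative, not part of the commit):

    #include <cstddef>
    #include "ck/ck.hpp"

    // One F64 bundle packs sizeof(double) / sizeof(half_t) == 4 fp16 values.
    constexpr std::size_t kNumElemsInBundle = sizeof(double) / sizeof(ck::half_t);
    static_assert(kNumElemsInBundle == 4);

    // An fp16 tensor of shape [H, W, 4] occupies the same bytes as an f64 tensor of
    // shape [1, H, W]; permuting the bundle tensor with axes {0, 2, 1} therefore swaps
    // H and W while keeping each group of 4 halves contiguous in memory.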
example/39_permute/permute_NxHxW_fp16.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

using InDataType  = F16;
using OutDataType = F16;

// clang-format off
using DevicePermuteInstance = ck::tensor_operation::device::DevicePermuteImpl
// ######| NumDim| InData| OutData| Elementwise| Block| NPer| HPer| WPer|   InBlock|      InBlockTransfer|           InBlockTransfer|       Src|       Dst|             Src|             Dst|
// ######|       |   Type|    Type|   Operation|  Size| Block| Block| Block| LdsExtraW| ThreadClusterLengths| ThreadClusterArrangeOrder| VectorDim| VectorDim| ScalarPerVector| ScalarPerVector|
         <      3, InDataType, OutDataType, PassThrough, 128, 4, 16, 8, 6, S<2, 16, 4>, S<0, 1, 2>, 2, 1, 2, 1>;
// clang-format on

#include "run_permute_element_example.inc"

int main() { return !run_permute_element_example({121, 768, 80}, {0, 2, 1}); }
example/39_permute/run_permute_bundle_example.inc (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

bool run_permute_bundle(const Problem& problem)
{
    const auto& input_bundle_shape = problem.shape;
    const auto& input_bundle_axes  = problem.axes;

    const auto output_bundle_shape = transpose(input_bundle_shape, input_bundle_axes);

    Tensor<BundleType> input_bundle_tensor(input_bundle_shape);
    Tensor<BundleType> output_bundle_tensor(output_bundle_shape);

    // initialize tensor by assigning DataType values
    ck::utils::FillUniformDistribution<DataType>{-1.f, 1.f}(input_bundle_tensor.AsSpan<DataType>());

    DeviceMem input_device_buf(input_bundle_tensor.GetElementSpaceSizeInBytes());
    DeviceMem output_device_buf(output_bundle_tensor.GetElementSpaceSizeInBytes());

    using std::data;
    input_device_buf.ToDevice(data(input_bundle_tensor));

    static_assert(std::is_default_constructible_v<DevicePermuteInstance>);

    auto permute  = DevicePermuteInstance{};
    auto argument = permute.MakeArgument(to_array(input_bundle_shape),
                                         to_array(input_bundle_tensor.GetStrides()),
                                         to_array(output_bundle_shape),
                                         to_array(output_bundle_tensor.GetStrides()),
                                         input_device_buf.GetDeviceBuffer(),
                                         output_device_buf.GetDeviceBuffer(),
                                         PassThrough{});

    if(!permute.IsSupportedArgument(argument))
    {
        std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
                  << std::endl;
        return false;
    };

    auto invoker = permute.MakeInvoker();

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
    std::cout << "Perf: " << ave_time << " ms" << std::endl;

    output_device_buf.FromDevice(data(output_bundle_tensor));

    constexpr std::size_t NumElemsInBundle = sizeof(BundleType) / sizeof(DataType);

    // extend tensor shape from [N, H, W] to [N, H, W, NumElemsInBundle]
    //        axes from [0, 2, 1] to [0, 2, 1, 3]
    const auto input_shape = extend_shape(input_bundle_shape, NumElemsInBundle);
    const auto input_axes  = extend_axes(input_bundle_axes);

    using std::begin;

    Tensor<DataType> input_tensor(input_shape);
    ranges::copy(input_bundle_tensor.AsSpan<const DataType>(), begin(input_tensor));

    Tensor<DataType> output_tensor(transpose(input_shape, input_axes));
    if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor))
    {
        return false;
    }

    return ck::utils::check_err(output_bundle_tensor.AsSpan<const DataType>(),
                                output_tensor.AsSpan<const DataType>(),
                                "Error: incorrect results in output tensor",
                                1e-6,
                                1e-6);
}

bool run_permute_bundle_example(const Problem::Shape& shape, const Problem::Axes& axes)
{
    return run_permute_bundle(Problem{shape, axes});
}
example/39_permute/run_permute_element_example.inc (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

bool run_permute_element(const Problem& problem)
{
    const auto& input_shape = problem.shape;
    const auto& input_axes  = problem.axes;

    const auto output_shape = transpose(input_shape, input_axes);

    Tensor<InDataType> input_tensor(input_shape);
    Tensor<OutDataType> output_tensor(output_shape);

    ck::utils::FillUniformDistribution<InDataType>{-1.f, 1.f}(input_tensor);

    DeviceMem input_device_buf(input_tensor.GetElementSpaceSizeInBytes());
    DeviceMem output_device_buf(output_tensor.GetElementSpaceSizeInBytes());

    using std::data;
    input_device_buf.ToDevice(data(input_tensor));

    static_assert(std::is_default_constructible_v<DevicePermuteInstance>);

    auto permute  = DevicePermuteInstance{};
    auto argument = permute.MakeArgument(to_array(input_shape),
                                         to_array(input_tensor.GetStrides()),
                                         to_array(output_shape),
                                         to_array(output_tensor.GetStrides()),
                                         input_device_buf.GetDeviceBuffer(),
                                         output_device_buf.GetDeviceBuffer(),
                                         PassThrough{});

    if(!permute.IsSupportedArgument(argument))
    {
        std::cerr << "The runtime parameters seems not supported by the device instance, exiting!"
                  << std::endl;
        return false;
    };

    auto invoker = permute.MakeInvoker();

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
    std::cout << "Perf: " << ave_time << " ms" << std::endl;

    output_device_buf.FromDevice(data(output_tensor));

    Tensor<OutDataType> output_tensor_host(output_shape);
    if(!host_permute(input_tensor, input_axes, PassThrough{}, output_tensor_host))
    {
        return false;
    }

    return ck::utils::check_err(output_tensor.AsSpan<const OutDataType>(),
                                output_tensor_host.AsSpan<const OutDataType>(),
                                "Error: incorrect results in output tensor",
                                1e-6,
                                1e-6);
}

bool run_permute_element_example(const Problem::Shape& shape, const Problem::Axes& axes)
{
    return run_permute_element(Problem{shape, axes});
}
example/42_groupnorm/CMakeLists.txt (new file, mode 100644)

add_example_executable(example_groupnorm_sigmoid_fp16 groupnorm_sigmoid_fp16.cpp)
example/42_groupnorm/groupnorm_sigmoid_fp16.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/device_layernorm_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp"
constexpr int Rank         = 5;
constexpr int NumReduceDim = 3;

using XDataType     = ck::half_t;
using GammaDataType = ck::half_t;
using BetaDataType  = ck::half_t;
using YDataType     = ck::half_t;
using AccDataType   = float;

struct YElementOp
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
                          ck::is_same<T, ck::half_t>::value,
                      "Data type is not supported by this operation!");

        T a;

        ck::tensor_operation::element_wise::Sigmoid{}(a, x);

        y = x * a;
    };
};

using DeviceInstance = ck::tensor_operation::device::DeviceLayernormImpl<XDataType,
                                                                         GammaDataType,
                                                                         BetaDataType,
                                                                         AccDataType,
                                                                         YDataType,
                                                                         YElementOp,
                                                                         Rank,
                                                                         NumReduceDim,
                                                                         256, // BlockSize
                                                                         8,   // ClusterM
                                                                         32,  // ClusterK
                                                                         1,   // SliceM
                                                                         8,   // SliceK
                                                                         1,   // SrcVecDim (0=M, 1=K)
                                                                         8,   // SrcScalarPerVector
                                                                         1,   // GammaVecDim (0=M, 1=K)
                                                                         8,   // GammaScalarPerVector
                                                                         1,   // BetaVecDim (0=M, 1=K)
                                                                         8,   // BetaScalarPerVector
                                                                         8>;  // OutScalarPerVector

int main(int argc, char* argv[])
{
    ck::index_t N = 128;
    ck::index_t H = 16;
    ck::index_t W = 16;
    ck::index_t G = 32;
    ck::index_t C = 40;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 6)
    {
        N = std::stoi(argv[1]);
        H = std::stoi(argv[2]);
        W = std::stoi(argv[3]);
        G = std::stoi(argv[4]);
        C = std::stoi(argv[5]);
    }
    else
    {
        std::cerr << "arg1 to 5: N, H, W, G, C" << std::endl;
        return 1;
    }

    Tensor<XDataType> x({N, H, W, G, C});
    Tensor<YDataType> y({N, H, W, G, C});
    Tensor<GammaDataType> gamma({G, C});
    Tensor<BetaDataType> beta({G, C});

    ck::utils::FillUniformDistribution<XDataType>{0.f, 1.f}(x.begin(), x.end());
    ck::utils::FillUniformDistribution<GammaDataType>{0.f, 1.f}(gamma.begin(), gamma.end());
    ck::utils::FillUniformDistribution<BetaDataType>{0.f, 1.f}(beta.begin(), beta.end());

    DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
    DeviceMem beta_dev(sizeof(BetaDataType) * beta.mDesc.GetElementSpaceSize());
    DeviceMem y_dev(sizeof(YDataType) * y.mDesc.GetElementSpaceSize());

    x_dev.ToDevice(x.mData.data());
    gamma_dev.ToDevice(gamma.mData.data());
    beta_dev.ToDevice(beta.mData.data());

    const auto y_element_op = YElementOp{};

    auto device_instance = DeviceInstance{};
    auto argument_ptr    = device_instance.MakeArgumentPointer(
        {N, H, W, G, C},
        std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
        {0, 0, 0, C, 1},
        {0, 0, 0, C, 1},
        std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
        {1, 2, 4}, // reduction dimension: [H, W, C]
        1e-6,
        x_dev.GetDeviceBuffer(),
        gamma_dev.GetDeviceBuffer(),
        beta_dev.GetDeviceBuffer(),
        y_dev.GetDeviceBuffer(),
        y_element_op);

    if(!device_instance.IsSupportedArgument(argument_ptr.get()))
    {
        std::cout << "The runtime parameters are not supported" << std::endl;
        return 1;
    };

    auto invoker_ptr = device_instance.MakeInvokerPointer();
    float ave_time   = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true, true});

    std::size_t num_btype = sizeof(XDataType) * N * H * W * G * C +
                            sizeof(YDataType) * N * H * W * G * C +
                            sizeof(GammaDataType) * G * C + sizeof(BetaDataType) * G * C;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s, "
              << device_instance.GetTypeString() << std::endl;

    bool pass = true;
    {
        Tensor<YDataType> host_y({N, H, W, G, C});
        using ReferenceInstance = ck::tensor_operation::host::ReferenceGroupnorm<XDataType,
                                                                                 GammaDataType,
                                                                                 BetaDataType,
                                                                                 YDataType,
                                                                                 AccDataType,
                                                                                 YElementOp>;

        ReferenceInstance ref;
        auto ref_argument =
            ref.MakeArgument(x, gamma, beta, host_y, y_element_op, {N, H, W, G, C}, 1e-6);
        auto ref_invoker = ref.MakeInvoker();
        ref_invoker.Run(ref_argument);

        y_dev.FromDevice(y.mData.data());
        pass &= ck::utils::check_err(y.mData, host_y.mData, "Error: Incorrect results", 1e-3, 1e-3);
    }

    return (pass ? 0 : 1);
}
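The YElementOp in this example applies y = x * sigmoid(x) (often called SiLU/Swish) to each normalized value. A scalar reference of the same math, useful for sanity-checking the fused output op (this helper is illustrative and not part of the commit):

    #include <cmath>

    // Scalar reference for the fused output op: swish(x) = x * sigmoid(x).
    inline float swish_ref(float x)
    {
        const float sigmoid = 1.0f / (1.0f + std::exp(-x));
        return x * sigmoid;
    }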
include/ck/tensor_operation/gpu/device/device_base.hpp

@@ -3,6 +3,7 @@
 #pragma once
 
+#include <cmath>
 #include <string>
 
 #include "ck/stream_config.hpp"
include/ck/tensor_operation/gpu/device/device_elementwise.hpp

@@ -222,14 +222,9 @@ struct DeviceElementwise
         }
     };
 
-    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    static bool IsSupportedArgument(const Argument& arg)
     {
-        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
-
-        if(pArg == nullptr)
-            return false;
-
-        if(pArg->lengths_.back() % MPerThread != 0)
+        if(arg.lengths_.back() % MPerThread != 0)
             return false;
 
         auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
@@ -247,19 +242,40 @@ struct DeviceElementwise
         bool valid = true;
         static_for<0, NumInput, 1>{}([&](auto I) {
             if(!IsScalarPerVectorValid(
-                   pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
+                   arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
                 valid = false;
         });
 
         static_for<0, NumOutput, 1>{}([&](auto I) {
             if(!IsScalarPerVectorValid(
-                   pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
+                   arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
                 valid = false;
         });
 
         return valid;
     };
 
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
+    }
+
+    static auto
+    MakeArgument(const std::array<index_t, NumDim> lengths,
+                 const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
+                 const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
+                 const std::array<const void*, NumInput> in_dev_buffers,
+                 const std::array<void*, NumOutput> out_dev_buffers,
+                 ElementwiseOperation elementwise_op)
+    {
+        return Argument{lengths,
+                        inStridesArray,
+                        outStridesArray,
+                        in_dev_buffers,
+                        out_dev_buffers,
+                        elementwise_op};
+    }
+
     std::unique_ptr<BaseArgument>
     MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
                         const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp

@@ -23,11 +23,10 @@ template <typename GridwiseReduction,
           typename YDataType,
           typename AccDataType,
           typename AccElementwiseOperation,
-          typename GridDesc_M_K,
-          typename GridDesc_K>
+          typename GridDesc_M_K>
 __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
-                                 const GridDesc_K gamma_grid_desc_k,
-                                 const GridDesc_K beta_grid_desc_k,
+                                 const GridDesc_M_K gamma_grid_desc_m_k,
+                                 const GridDesc_M_K beta_grid_desc_m_k,
                                  const GridDesc_M_K y_grid_desc_m_k,
                                  index_t num_k_block_tile_iteration,
                                  AccDataType epsilon,
@@ -38,8 +37,8 @@ __global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
                                  const AccElementwiseOperation acc_elementwise_op)
 {
     GridwiseReduction::Run(x_grid_desc_m_k,
-                           gamma_grid_desc_k,
-                           beta_grid_desc_k,
+                           gamma_grid_desc_m_k,
+                           beta_grid_desc_m_k,
                            y_grid_desc_m_k,
                            num_k_block_tile_iteration,
                            epsilon,
@@ -71,7 +70,9 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XYSrcVectorDim,
           index_t XSrcVectorSize,
+          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
+          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize>
 struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
@@ -84,11 +85,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                     NumReduceDim>
 {
     static_assert(
-        (KThreadSliceSize % GammaSrcVectorSize == 0),
+        ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
+         (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
         "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
 
     static_assert(
-        (KThreadSliceSize % BetaSrcVectorSize == 0),
+        ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
+         (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
         "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
 
     using PassThrough = tensor_operation::element_wise::PassThrough;
@@ -162,38 +165,7 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
         return (in_grid_desc_m_k_padded);
     };
 
-    static auto MakeAffine1dDescriptor(const std::vector<index_t>& Lengths,
-                                       const std::vector<index_t>& Strides,
-                                       int blkGroupSize,
-                                       int numBlockTileIteration)
-    {
-        const auto tupleLengths = make_tuple_from_array(Lengths, Number<NumReduceDim>{});
-        const auto tupleStrides = make_tuple_from_array(Strides, Number<NumReduceDim>{});
-
-        auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
-
-        auto grid_desc_k = transform_tensor_descriptor(
-            desc,
-            make_tuple(make_merge_transform(tupleLengths)),
-            make_tuple(typename arithmetic_sequence_gen<0, NumReduceDim, 1>::type{}),
-            make_tuple(Sequence<0>{}));
-
-        const auto reduceTotalLength = grid_desc_k.GetLength(Number<0>{});
-        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
-        const auto Pad_K             = reduceSizePerBlock * blkGroupSize - reduceTotalLength;
-
-        auto grid_desc_k_padded = transform_tensor_descriptor(
-            grid_desc_k,
-            make_tuple(make_right_pad_transform(reduceTotalLength, Pad_K)),
-            make_tuple(Sequence<0>{}),
-            make_tuple(Sequence<0>{}));
-
-        return (grid_desc_k_padded);
-    };
-
     using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
-    using GridDesc_K   = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1));
 
     using GridwiseReduceLayernormGeneric =
         GridwiseLayernormWelfordVariance_mk_to_mk<XDataType,
@@ -203,7 +175,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                   AccDataType,
                                                   AccElementwiseOperation,
                                                   GridDesc_M_K,
-                                                  GridDesc_K,
                                                   BlockSize,
                                                   MThreadClusterSize,
                                                   KThreadClusterSize,
@@ -211,12 +182,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                   KThreadSliceSize,
                                                   XYSrcVectorDim,
                                                   XSrcVectorSize,
+                                                  GammaSrcVectorDim,
                                                   GammaSrcVectorSize,
+                                                  BetaSrcVectorDim,
                                                   BetaSrcVectorSize,
                                                   XYSrcVectorDim,
                                                   YDstVectorSize,
                                                   false>;
 
     using GridwiseReduceLayernormSweepOnce =
         GridwiseLayernormWelfordVariance_mk_to_mk<XDataType,
                                                   GammaDataType,
@@ -225,7 +197,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                   AccDataType,
                                                   AccElementwiseOperation,
                                                   GridDesc_M_K,
-                                                  GridDesc_K,
                                                   BlockSize,
                                                   MThreadClusterSize,
                                                   KThreadClusterSize,
@@ -233,7 +204,9 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                   KThreadSliceSize,
                                                   XYSrcVectorDim,
                                                   XSrcVectorSize,
+                                                  GammaSrcVectorDim,
                                                   GammaSrcVectorSize,
+                                                  BetaSrcVectorDim,
                                                   BetaSrcVectorSize,
                                                   XYSrcVectorDim,
                                                   YDstVectorSize,
@@ -258,13 +231,13 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_y_(p_y),
-              gammaStrides_(gammaStrides),
-              betaStrides_(betaStrides),
               acc_elementwise_op_(acc_elementwise_op)
         {
             Lengths_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
             xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
             yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
+            gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
+            betaStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
 
             long_index_t invariant_total_length;
             long_index_t reduce_total_length;
@@ -278,12 +251,17 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
             gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                         M_BlockTileSize * blkGroupSize_;
 
-            reduceLengths_.resize(NumReduceDim);
-
-            for(int i = 0; i < NumReduceDim; ++i)
-            {
-                reduceLengths_[i] = lengths[reduceDims[i]];
-            }
+            x_grid_desc_m_k_ =
+                MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
+            gamma_grid_desc_m_k_ =
+                MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_);
+            beta_grid_desc_m_k_ =
+                MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_);
+            y_grid_desc_m_k_ =
+                MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);
+
+            isSweeponce_ =
+                x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
         }
 
         AccDataType epsilon_;
@@ -295,7 +273,6 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
         std::vector<index_t> Lengths_;
         std::vector<index_t> xStrides_;
-        std::vector<index_t> reduceLengths_;
         std::vector<index_t> gammaStrides_;
        	std::vector<index_t> betaStrides_;
        	std::vector<index_t> yStrides_;
@@ -305,46 +282,35 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
         int blkGroupSize_;
         int numBlockTileIteration_;
         size_t gridSize_;
+        GridDesc_M_K x_grid_desc_m_k_;
+        GridDesc_M_K gamma_grid_desc_m_k_;
+        GridDesc_M_K beta_grid_desc_m_k_;
+        GridDesc_M_K y_grid_desc_m_k_;
+        bool isSweeponce_;
     };
 
     struct Invoker : public BaseInvoker
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        	{
-            const auto x_grid_desc_m_k = MakeSrc2dDescriptor(
-                arg.Lengths_, arg.xStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
-            const auto gamma_grid_desc_k = MakeAffine1dDescriptor(
-                arg.reduceLengths_, arg.gammaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
-            const auto beta_grid_desc_k = MakeAffine1dDescriptor(
-                arg.reduceLengths_, arg.betaStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
-            const auto y_grid_desc_m_k = MakeSrc2dDescriptor(
-                arg.Lengths_, arg.yStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
-
-            bool sweep_once =
-                x_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
-
-            const auto kernel_main = sweep_once ? kernel_layernorm<GridwiseReduceLayernormSweepOnce,
-                                                                   XDataType,
-                                                                   GammaDataType,
-                                                                   BetaDataType,
-                                                                   YDataType,
-                                                                   AccDataType,
-                                                                   AccElementwiseOperation,
-                                                                   GridDesc_M_K,
-                                                                   GridDesc_K>
-                                                : kernel_layernorm<GridwiseReduceLayernormGeneric,
-                                                                   XDataType,
-                                                                   GammaDataType,
-                                                                   BetaDataType,
-                                                                   YDataType,
-                                                                   AccDataType,
-                                                                   AccElementwiseOperation,
-                                                                   GridDesc_M_K,
-                                                                   GridDesc_K>;
+            const auto kernel_main = arg.isSweeponce_
+                                         ? kernel_layernorm<GridwiseReduceLayernormSweepOnce,
+                                                            XDataType,
+                                                            GammaDataType,
+                                                            BetaDataType,
+                                                            YDataType,
+                                                            AccDataType,
+                                                            AccElementwiseOperation,
+                                                            GridDesc_M_K>
+                                         : kernel_layernorm<GridwiseReduceLayernormGeneric,
+                                                            XDataType,
+                                                            GammaDataType,
+                                                            BetaDataType,
+                                                            YDataType,
+                                                            AccDataType,
+                                                            AccElementwiseOperation,
+                                                            GridDesc_M_K>;
 
             float avg_time = 0;
 
             avg_time += launch_and_time_kernel(stream_config,
@@ -352,10 +318,10 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
                                                dim3(arg.gridSize_),
                                                dim3(BlockSize),
                                                0,
-                                               x_grid_desc_m_k,
-                                               gamma_grid_desc_k,
-                                               beta_grid_desc_k,
-                                               y_grid_desc_m_k,
+                                               arg.x_grid_desc_m_k_,
+                                               arg.gamma_grid_desc_m_k_,
+                                               arg.beta_grid_desc_m_k_,
+                                               arg.y_grid_desc_m_k_,
                                                arg.numBlockTileIteration_,
                                                arg.epsilon_,
                                                arg.p_x_,
@@ -409,26 +375,41 @@ struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
             return false;
        	}
 
-        if(p_arg_->gammaStrides_.size() != NumReduceDim ||
-           p_arg_->betaStrides_.size() != NumReduceDim)
-            return false;
-
-        auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
-            bool ret = true;
-
-            if(!isLastDimensionCoalesced)
-                ret = scalarPerVector == 1;
-            else
-                ret = KThreadSliceSize % scalarPerVector == 0;
-
-            return ret;
-        };
-
-        if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize))
-            return false;
-
-        if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize))
-            return false;
+        // if fastest dim is not reduced
+        if constexpr(GammaSrcVectorDim == 0)
+        {
+            if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
+                return (false);
+
+            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
+                return (false);
+        }
+        else // if fastest dim is reduced
+        {
+            if(p_arg_->gammaStrides_[Rank - 1] != 1)
+                return (false);
+
+            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
+                return (false);
+        }
+
+        // if fastest dim is not reduced
+        if constexpr(BetaSrcVectorDim == 0)
+        {
+            if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
+                return (false);
+
+            if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
+                return (false);
+        }
+        else // if fastest dim is reduced
+        {
+            if(p_arg_->betaStrides_[Rank - 1] != 1)
+                return (false);
+
+            if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
+                return (false);
+        }
 
         return true;
     };
include/ck/tensor_operation/gpu/device/device_permute.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <cmath>
#include <memory>
#include <type_traits>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t NumDim,
          typename InDataType,
          typename OutDataType,
          typename ElementwiseOperation>
struct DevicePermute : BaseOperator
{
    using Lengths = std::array<index_t, NumDim>;
    using Strides = Lengths;

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const Lengths& in_lengths,
                        const Strides& in_strides,
                        const Lengths& out_lengths,
                        const Strides& out_strides,
                        const void* in_dev_buffer,
                        void* out_dev_buffer,
                        ElementwiseOperation elementwise_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
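For reference, a minimal sketch of how a concrete instance of this interface is driven from the host side (the DevicePermuteInstance alias, the in_dev/out_dev DeviceMem buffers, and the row-major strides are assumptions for illustration; the full flow lives in example/39_permute):

    // Sketch only: assumes a DevicePermuteInstance alias like the ones in example/39_permute
    // and device buffers in_dev/out_dev that already hold the input data.
    auto permute  = DevicePermuteInstance{};
    auto argument = permute.MakeArgumentPointer({1, 32000, 80},      // in_lengths
                                                {2560000, 80, 1},    // in_strides (row-major)
                                                {1, 80, 32000},      // out_lengths
                                                {2560000, 32000, 1}, // out_strides
                                                in_dev.GetDeviceBuffer(),
                                                out_dev.GetDeviceBuffer(),
                                                PassThrough{});

    if(permute.IsSupportedArgument(argument.get()))
    {
        permute.MakeInvokerPointer()->Run(argument.get());
    }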
include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <memory>
#include <utility>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// Swap last 2 dimensions
// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
// ^^^^^^^^^^^
// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
// ^^^^^^^^^^^
template
<
index_t
NumDim
,
typename
InDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
index_t
BlockSize
,
index_t
NPerBlock
,
index_t
HPerBlock
,
index_t
WPerBlock
,
index_t
InBlockLdsExtraW
,
typename
InBlockTransferThreadClusterLengths
,
typename
InBlockTransferThreadClusterArrangeOrder
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
DstScalarPerVector
>
struct
DevicePermuteImpl
:
DevicePermute
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
{
using
BaseType
=
DevicePermute
<
NumDim
,
InDataType
,
OutDataType
,
ElementwiseOperation
>
;
using
typename
BaseType
::
Lengths
;
using
typename
BaseType
::
Strides
;
static_assert
(
3
<=
NumDim
,
"Only accept at least 3D dimension tensor"
);
static_assert
((
NumDim
-
2
)
<=
                  SrcVectorDim && SrcVectorDim < NumDim);
    static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
    static_assert(SrcVectorDim != DstVectorDim);

    template <index_t N = NumDim>
    static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
    {
        static_assert(1 <= N && N <= NumDim);

        return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
    }

    static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides& stride)
    {
        // create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
        // d[NumDim-1]]
        const auto desc =
            make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride));

        // merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
        // d[NumDim-1]]
        // => [N, H, W]
        const index_t H = *std::next(rbegin(lengths));
        const index_t W = *rbegin(lengths);

        const auto desc_n_h_w = transform_tensor_descriptor(
            desc,
            make_tuple(make_merge_transform(ConvertArrayToTuple<NumDim - 2>(lengths)),
                       make_pass_through_transform(H),
                       make_pass_through_transform(W)),
            make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
                       Sequence<NumDim - 2>{},
                       Sequence<NumDim - 1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        return PadTensorDescriptor(
            desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
    }

    using InGridDesc  = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
    using OutGridDesc = InGridDesc;

    using GridwisePermute = GridwisePermute<
        InGridDesc,
        OutGridDesc,
        InDataType,
        OutDataType,
        ElementwiseOperation,
        BlockSize,
        NPerBlock,
        HPerBlock,
        WPerBlock,
        InBlockLdsExtraW,
        InBlockTransferThreadClusterLengths,
        InBlockTransferThreadClusterArrangeOrder,
        SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
        DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
        SrcScalarPerVector,
        DstScalarPerVector>;

    using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap;

    struct Argument : public BaseArgument
    {
        Argument(const Lengths& in_lengths,
                 const Strides& in_strides,
                 const Lengths& out_lengths,
                 const Strides& out_strides,
                 const void* in_dev_buffer,
                 void* out_dev_buffer,
                 ElementwiseOperation elementwise_op)
            : in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
              out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
              in_grid_desc_(MakeDescriptor_N_H_W(in_lengths, in_strides)),
              out_grid_desc_(MakeDescriptor_N_H_W(out_lengths, out_strides)),
              in_lengths_(in_lengths),
              in_strides_(in_strides),
              out_lengths_(out_lengths),
              out_strides_(out_strides),
              elementwise_op_(elementwise_op),
              block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
        {
        }

        const InDataType* in_dev_buffer_;
        OutDataType* out_dev_buffer_;

        InGridDesc in_grid_desc_;
        OutGridDesc out_grid_desc_;

        Lengths in_lengths_;
        Strides in_strides_;
        Lengths out_lengths_;
        Strides out_strides_;

        ElementwiseOperation elementwise_op_;

        Block2TileMap block_2_tile_map_;
    };

    struct Invoker : BaseInvoker
    {
        static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_);

            const auto kernel = kernel_nd_permute<GridwisePermute,
                                                  InGridDesc,
                                                  OutGridDesc,
                                                  InDataType,
                                                  OutDataType,
                                                  ElementwiseOperation,
                                                  Block2TileMap>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(grid_size),
                                                        dim3(BlockSize),
                                                        0,
                                                        arg.in_grid_desc_,
                                                        arg.out_grid_desc_,
                                                        arg.in_dev_buffer_,
                                                        arg.out_dev_buffer_,
                                                        arg.elementwise_op_,
                                                        arg.block_2_tile_map_);
            return elapsed_time;
        }

        float Run(const BaseArgument* arg,
                  const StreamConfig& stream_config = StreamConfig{}) override final
        {
            const auto* const argument = dynamic_cast<const Argument*>(arg);
            if(!argument)
            {
                return NAN;
            }

            return Run(*argument, stream_config);
        }
    };

    static bool IsSupportedArgument(const Argument& arg)
    {
        constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) {
            return math::integer_divide_ceil(length, tile_length) * tile_length;
        };

        constexpr auto IsScalarPerVectorValid =
            [](index_t length, index_t stride, index_t scalar_per_vector) {
                if(stride == 1 && length % scalar_per_vector == 0)
                {
                    return true;
                }
                else if(stride != 1 && scalar_per_vector == 1)
                {
                    return true;
                }

                return false;
            };

        return IsScalarPerVectorValid(arg.in_lengths_[SrcVectorDim],
                                      arg.in_strides_[SrcVectorDim],
                                      SrcScalarPerVector) &&
               IsScalarPerVectorValid(
                   GetPaddedLength(arg.in_lengths_[SrcVectorDim],
                                   (SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
                   arg.in_strides_[SrcVectorDim],
                   SrcScalarPerVector) &&
               IsScalarPerVectorValid(arg.out_lengths_[DstVectorDim],
                                      arg.out_strides_[DstVectorDim],
                                      DstScalarPerVector) &&
               IsScalarPerVectorValid(
                   GetPaddedLength(arg.out_lengths_[DstVectorDim],
                                   (DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
                   arg.in_strides_[DstVectorDim],
                   DstScalarPerVector) &&
               GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
    };

    // override methods inherited from 'BaseOperator'
    bool IsSupportedArgument(const BaseArgument* arg) override final
    {
        const auto* const argument = dynamic_cast<const Argument*>(arg);
        if(!argument)
        {
            return false;
        }

        return IsSupportedArgument(*argument);
    }

    // override methods inherited from 'DevicePermute'
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const Lengths& in_lengths,
                                                      const Strides& in_strides,
                                                      const Lengths& out_lengths,
                                                      const Strides& out_strides,
                                                      const void* in_dev_buffer,
                                                      void* out_dev_buffer,
                                                      ElementwiseOperation elementwise_op) override final
    {
        return std::make_unique<Argument>(in_lengths,
                                          in_strides,
                                          out_lengths,
                                          out_strides,
                                          in_dev_buffer,
                                          out_dev_buffer,
                                          elementwise_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
    {
        return std::make_unique<Invoker>();
    };

    // other constructor methods
    template <typename... Args>
    static std::enable_if_t<std::is_constructible_v<Argument, Args...>, Argument>
    MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v<Argument, Args...>)
    {
        return Argument{std::forward<Args>(args)...};
    }

    static std::enable_if_t<std::is_default_constructible_v<Invoker>, Invoker>
    MakeInvoker() noexcept(std::is_nothrow_default_constructible_v<Invoker>)
    {
        return Invoker{};
    }
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
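
The operator above follows CK's usual argument/invoker pattern: build an Argument from lengths, strides and device pointers, check it with IsSupportedArgument, then launch through an Invoker. A host-side sketch of that flow is shown below; the alias `Permute`, the buffer pointers and the length/stride variables are placeholders for illustration and are not taken from this commit.

    // Hypothetical usage sketch; `Permute` is an assumed alias for a concrete
    // specialization of the device permute operator defined above.
    // using Permute = ...;
    Permute permute;

    auto argument = permute.MakeArgumentPointer(in_lengths,   // std::array<index_t, NumDim>
                                                in_strides,
                                                out_lengths,
                                                out_strides,
                                                in_dev_buf,   // device pointer to the source tensor
                                                out_dev_buf,  // device pointer to the destination tensor
                                                ElementwiseOperation{}); // e.g. a pass-through op

    if(permute.IsSupportedArgument(argument.get()))
    {
        auto invoker = permute.MakeInvokerPointer();
        // Run() returns the elapsed time reported by launch_and_time_kernel.
        const float elapsed_ms = invoker->Run(argument.get());
    }
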
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @ 24f99138

@@ -232,6 +232,21 @@ struct Gelu
     }
 };
 
+struct Sigmoid
+{
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, ck::half_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = 1 / (ck::type_convert<T>(1) + exp(-x));
+    };
+
+    int32_t divider_ = 1;
+};
+
 } // namespace element_wise
 } // namespace tensor_operation
 } // namespace ck
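
The new Sigmoid functor computes y = 1 / (1 + exp(-x)) and is restricted to float, double and half at compile time. A minimal host-side sketch of calling it follows; it assumes the repository's include/ directory is on the include path and that the file is built with the HIP toolchain so the __host__ __device__ qualifiers resolve.

    // Illustration only: exercise the new Sigmoid elementwise op on the host.
    #include <cstdio>

    #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

    int main()
    {
        ck::tensor_operation::element_wise::Sigmoid sigmoid{};

        float y = 0.0f;
        sigmoid(y, 0.0f); // sigmoid(0) == 0.5
        std::printf("sigmoid(0) = %f\n", y);

        sigmoid(y, 4.0f); // saturates towards 1 for large positive inputs
        std::printf("sigmoid(4) = %f\n", y);

        return 0;
    }
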
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp
View file @ 24f99138

@@ -83,6 +83,8 @@ struct GridwiseElementwise_1D
         auto in_global_buf_tuple = generate_tuple(
             [&](auto I) {
+                static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
+
                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
             },
@@ -90,6 +92,8 @@ struct GridwiseElementwise_1D
         auto out_global_buf_tuple = generate_tuple(
             [&](auto I) {
+                static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
+
                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
                     p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
             },
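
The two added static_asserts make the kernel's contract explicit: every descriptor handed to GridwiseElementwise_1D must report exactly one dimension, so a wrong-rank descriptor now fails at compile time instead of misbehaving at run time. The same pattern in isolation, with toy stand-ins rather than CK descriptors:

    // Toy illustration of the compile-time rank check; ToyDesc is a stand-in,
    // not a CK tensor descriptor.
    template <int NDim>
    struct ToyDesc
    {
        static constexpr int GetNumOfDimension() { return NDim; }
    };

    template <typename Desc>
    void use_as_1d(const Desc&)
    {
        static_assert(Desc::GetNumOfDimension() == 1, "descriptor must be 1-D");
    }

    int main()
    {
        use_as_1d(ToyDesc<1>{});    // compiles
        // use_as_1d(ToyDesc<2>{}); // would be rejected at compile time, as intended
        return 0;
    }
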
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_naive_variance.hpp
View file @ 24f99138

@@ -22,7 +22,6 @@ template <typename XDataType,
           typename AccDataType,
           typename AccElementwiseOperation,
           typename GridDesc_M_K,
-          typename GridDesc_K,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t KThreadClusterSize,
@@ -30,7 +29,9 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XSrcVectorDim,
           index_t XSrcVectorSize,
+          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
+          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
@@ -78,13 +79,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
 
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
 
     __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
-                               const GridDesc_K& gamma_grid_desc_k,
-                               const GridDesc_K& beta_grid_desc_k,
+                               const GridDesc_M_K& gamma_grid_desc_m_k,
+                               const GridDesc_M_K& beta_grid_desc_m_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
                                index_t num_k_block_tile_iteration,
                                AccDataType epsilon,
@@ -111,11 +113,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             x_thread_buf;
 
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
-            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>&
+            beta_thread_buf = gamma_thread_buf;
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             y_thread_buf;
@@ -127,7 +132,7 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
             mean_square_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_value_buf =
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& var_thread_buf =
            mean_square_thread_buf;
 
         static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
@@ -145,11 +150,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         const auto thread_k_cluster_id = thread_cluster_idx[I1];
 
         using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
-        using ThreadBufferLengths_K   = Sequence<KThreadSliceSize>;
         constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
-        constexpr auto thread_buffer_desc_k =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
 
         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                   AccDataType,
@@ -169,27 +171,34 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         auto threadwise_gamma_load =
             ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                              AccDataType,
-                                             GridDesc_K,
-                                             decltype(thread_buffer_desc_k),
-                                             ThreadBufferLengths_K,
-                                             Sequence<0>,
-                                             0,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc_m_k),
+                                             ThreadBufferLengths_M_K,
+                                             ThreadBufferDimAccessOrder,
+                                             GammaSrcVectorDim,
                                              GammaSrcVectorSize,
                                              1,
                                              true>(
-                gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
-
-        auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
-                                                                     AccDataType,
-                                                                     GridDesc_K,
-                                                                     decltype(thread_buffer_desc_k),
-                                                                     ThreadBufferLengths_K,
-                                                                     Sequence<0>,
-                                                                     0,
-                                                                     BetaSrcVectorSize,
-                                                                     1,
-                                                                     true>(
-            beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+                gamma_grid_desc_m_k,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));
+
+        auto threadwise_beta_load =
+            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
+                                             AccDataType,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc_m_k),
+                                             ThreadBufferLengths_M_K,
+                                             ThreadBufferDimAccessOrder,
+                                             BetaSrcVectorDim,
+                                             BetaSrcVectorSize,
+                                             1,
+                                             true>(
+                beta_grid_desc_m_k,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));
 
         auto threadwise_y_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
@@ -212,9 +221,6 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
         // Copy x from Cache
         // one pass: fwd, second pass: bwd
-        constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
-        constexpr auto thread_copy_bwd_step_k =
-            make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
         constexpr auto thread_copy_fwd_step_m_k =
             make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
         constexpr auto thread_copy_bwd_step_m_k =
@@ -224,13 +230,14 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
             p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
         const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
+            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
         const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
+            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
 
         // E(x), E[x^2], var(x)
-        int reduce_length = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
+        // FIXME: Should not hack the transform from deviceOP
+        int reduce_length = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
 
         index_t reducedTiles = 0;
 
         do
@@ -271,17 +278,16 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
             mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length;
 
             // var(x) = E[x^2] - E[x]^2
-            var_value_buf(I) =
+            var_thread_buf(I) =
                 mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I));
         });
 
         // y = (x - E[x]) / sqrt(var[x] + epsilon)
         auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
-        auto thread_copy_tail_k   = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
 
         threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
-        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
+        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
+        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
         threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
 
         reducedTiles = 0;
@@ -296,10 +302,10 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                                       x_thread_buf);
             }
 
-            threadwise_gamma_load.Run(gamma_grid_desc_k,
+            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                       gamma_global_val_buf,
-                                      thread_buffer_desc_k,
-                                      make_tuple(I0),
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
                                       gamma_thread_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
@@ -307,23 +313,21 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
 
                     // normalize
                     y_thread_buf(Number<offset_m_k>{}) =
                         (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
-                        sqrt(var_value_buf(iM) + epsilon);
+                        sqrt(var_thread_buf(iM) + epsilon);
 
                     // gamma
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
                 });
             });
 
-            threadwise_beta_load.Run(beta_grid_desc_k,
+            threadwise_beta_load.Run(beta_grid_desc_m_k,
                                      beta_global_val_buf,
-                                     thread_buffer_desc_k,
-                                     make_tuple(I0),
+                                     thread_buffer_desc_m_k,
+                                     make_tuple(I0, I0),
                                      beta_thread_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
@@ -331,11 +335,9 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
 
                     // beta
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
                 });
             });
@@ -346,8 +348,8 @@ struct GridwiseLayernormNaiveVariance_mk_to_mk
                                      y_global_val_buf);
 
             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
-            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
             threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
 
             ++reducedTiles;
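
Functionally, this change threads gamma and beta through the same (M, K) descriptor and thread-buffer layout as x and y, so their loads are indexed per (m, k) instead of per k only; the arithmetic is unchanged: mean = E[x], var = E[x^2] - E[x]^2, then y = gamma * (x - mean) / sqrt(var + epsilon) + beta. A host-side reference sketch of that per-row computation, useful for checking kernel output (plain C++, not part of the library; gamma and beta are taken here as length-K vectors broadcast over M):

    #include <cmath>
    #include <vector>

    // Reference layernorm using the naive variance formula, row-major M x K input.
    void layernorm_naive_ref(const std::vector<float>& x,
                             const std::vector<float>& gamma, // length K
                             const std::vector<float>& beta,  // length K
                             std::vector<float>& y,           // length M * K
                             int M, int K, float epsilon)
    {
        for(int m = 0; m < M; ++m)
        {
            float sum = 0.f, sum_sq = 0.f;
            for(int k = 0; k < K; ++k)
            {
                const float v = x[m * K + k];
                sum += v;
                sum_sq += v * v;
            }
            const float mean = sum / K;
            const float var  = sum_sq / K - mean * mean; // E[x^2] - E[x]^2
            for(int k = 0; k < K; ++k)
            {
                y[m * K + k] =
                    gamma[k] * (x[m * K + k] - mean) / std::sqrt(var + epsilon) + beta[k];
            }
        }
    }
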
include/ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp
View file @ 24f99138

@@ -19,7 +19,6 @@ template <typename XDataType,
           typename AccDataType,
           typename AccElementwiseOperation,
           typename GridDesc_M_K,
-          typename GridDesc_K,
           index_t BlockSize,
           index_t MThreadClusterSize,
           index_t KThreadClusterSize,
@@ -27,7 +26,9 @@ template <typename XDataType,
           index_t KThreadSliceSize,
           index_t XSrcVectorDim,
           index_t XSrcVectorSize,
+          index_t GammaSrcVectorDim,
           index_t GammaSrcVectorSize,
+          index_t BetaSrcVectorDim,
           index_t BetaSrcVectorSize,
           index_t YDstVectorDim,
           index_t YDstVectorSize,
@@ -70,6 +71,7 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
 
     static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
     static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
@@ -77,7 +79,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
     __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
                                         int thread_k_cluster_id)
     {
-        int kPerBlock = x_grid_desc_m_k.GetTransforms()[I0].GetUpperLengths()[I1];
+        // FIXME: Should not hack the transform from deviceOP
+        int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
         int kPerThread =
             kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
         int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
@@ -94,8 +97,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
     __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k,
-                               const GridDesc_K& gamma_grid_desc_k,
-                               const GridDesc_K& beta_grid_desc_k,
+                               const GridDesc_M_K& gamma_grid_desc_m_k,
+                               const GridDesc_M_K& beta_grid_desc_m_k,
                                const GridDesc_M_K& y_grid_desc_m_k,
                                index_t num_k_block_tile_iteration,
                                AccDataType epsilon,
@@ -116,11 +119,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             x_thread_buf;
 
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true> gamma_thread_buf;
-        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, KThreadSliceSize, true>& beta_thread_buf =
-            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
+            gamma_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>&
+            beta_thread_buf = gamma_thread_buf;
 
         StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
             y_thread_buf;
@@ -137,11 +143,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         const auto thread_k_cluster_id = thread_cluster_idx[I1];
 
         using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
-        using ThreadBufferLengths_K   = Sequence<KThreadSliceSize>;
         constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
             make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
-        constexpr auto thread_buffer_desc_k =
-            make_naive_tensor_descriptor_packed(make_tuple(Number<KThreadSliceSize>{}));
 
         auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                   AccDataType,
@@ -161,27 +164,34 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         auto threadwise_gamma_load =
             ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                              AccDataType,
-                                             GridDesc_K,
-                                             decltype(thread_buffer_desc_k),
-                                             ThreadBufferLengths_K,
-                                             Sequence<0>,
-                                             0,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc_m_k),
+                                             ThreadBufferLengths_M_K,
+                                             ThreadBufferDimAccessOrder,
+                                             GammaSrcVectorDim,
                                              GammaSrcVectorSize,
                                              1,
                                              true>(
-                gamma_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
-
-        auto threadwise_beta_load = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
-                                                                     AccDataType,
-                                                                     GridDesc_K,
-                                                                     decltype(thread_buffer_desc_k),
-                                                                     ThreadBufferLengths_K,
-                                                                     Sequence<0>,
-                                                                     0,
-                                                                     BetaSrcVectorSize,
-                                                                     1,
-                                                                     true>(
-            beta_grid_desc_k, make_multi_index(thread_k_cluster_id * KThreadSliceSize));
+                gamma_grid_desc_m_k,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));
+
+        auto threadwise_beta_load =
+            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
+                                             AccDataType,
+                                             GridDesc_M_K,
+                                             decltype(thread_buffer_desc_m_k),
+                                             ThreadBufferLengths_M_K,
+                                             ThreadBufferDimAccessOrder,
+                                             BetaSrcVectorDim,
+                                             BetaSrcVectorSize,
+                                             1,
+                                             true>(
+                beta_grid_desc_m_k,
+                make_multi_index(block_global_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_k_cluster_id * KThreadSliceSize));
 
         auto threadwise_y_store =
             ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
@@ -204,9 +214,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         // Copy x from Cache
         // one pass: fwd, second pass: bwd
-        constexpr auto thread_copy_fwd_step_k = make_multi_index(SweepOnce ? 0 : K_BlockTileSize);
-        constexpr auto thread_copy_bwd_step_k =
-            make_multi_index(SweepOnce ? 0 : -K_BlockTileSize);
         constexpr auto thread_copy_fwd_step_m_k =
             make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
         constexpr auto thread_copy_bwd_step_m_k =
@@ -216,10 +223,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
             p_x_global, x_grid_desc_m_k.GetElementSpaceSize());
         const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_gamma_global, gamma_grid_desc_k.GetElementSpaceSize());
+            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());
         const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_beta_global, beta_grid_desc_k.GetElementSpaceSize());
+            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());
 
         auto threadwise_welford       = ThreadwiseWelford();
         threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);
@@ -250,11 +257,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
         });
 
         auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;
-        auto thread_copy_tail_k   = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_k;
 
         threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_tail_k);
-        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_tail_k);
+        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
+        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
         threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);
 
         for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
@@ -268,10 +274,10 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                                       x_thread_buf);
             }
 
-            threadwise_gamma_load.Run(gamma_grid_desc_k,
+            threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                       gamma_global_val_buf,
-                                      thread_buffer_desc_k,
-                                      make_tuple(I0),
+                                      thread_buffer_desc_m_k,
+                                      make_tuple(I0, I0),
                                       gamma_thread_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
@@ -279,8 +285,6 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
 
                     // normalize
                     y_thread_buf(Number<offset_m_k>{}) =
                         (x_thread_buf(Number<offset_m_k>{}) - mean_thread_buf(iM)) /
@@ -288,14 +292,14 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                     // gamma
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) * gamma_thread_buf(Number<offset_m_k>{});
                 });
             });
 
-            threadwise_beta_load.Run(beta_grid_desc_k,
+            threadwise_beta_load.Run(beta_grid_desc_m_k,
                                      beta_global_val_buf,
-                                     thread_buffer_desc_k,
-                                     make_tuple(I0),
+                                     thread_buffer_desc_m_k,
+                                     make_tuple(I0, I0),
                                      beta_thread_buf);
 
             static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
@@ -303,11 +307,9 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                     constexpr auto offset_m_k =
                         thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
-                    constexpr auto offset_k = thread_buffer_desc_k.CalculateOffset(make_tuple(iK));
 
                     // beta
                     y_thread_buf(Number<offset_m_k>{}) =
-                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_k>{});
+                        y_thread_buf(Number<offset_m_k>{}) + beta_thread_buf(Number<offset_m_k>{});
                 });
             });
@@ -318,8 +320,8 @@ struct GridwiseLayernormWelfordVariance_mk_to_mk
                                      y_global_val_buf);
 
             threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
-            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_k, thread_copy_bwd_step_k);
-            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_k, thread_copy_bwd_step_k);
+            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
+            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_bwd_step_m_k);
             threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k);
         }
     }
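
The Welford variant receives the same gamma/beta descriptor change; its mean and variance come from the threadwise/blockwise Welford helpers rather than the E[x^2] - E[x]^2 formula. For reference, the scalar recurrence those helpers are built on looks like the sketch below (illustration only, not CK code): each new sample updates the running mean and the running sum of squared deviations, which stays numerically stable over long reductions.

    // Scalar sketch of Welford's streaming mean/variance update.
    struct WelfordState
    {
        int count  = 0;
        float mean = 0.f;
        float m2   = 0.f; // running sum of squared deviations from the mean

        void update(float x)
        {
            ++count;
            const float delta = x - mean;
            mean += delta / count;
            m2 += delta * (x - mean); // note: uses the updated mean
        }

        float variance() const { return count > 0 ? m2 / count : 0.f; } // population variance
    };
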