gaoqiong / composable_kernel_ROCM — Commits

Commit e6bb1dd7 (unverified)
Authored Jul 19, 2024 by Po Yen Chen; committed by GitHub on Jul 19, 2024

Merge branch 'develop' into feature/check-window-lengths

Parents: 9d6a3704, ab250afd
Changes: 317 files total; showing 20 changed files with 713 additions and 224 deletions (+713, -224).
include/ck_tile/core/numeric/integral_constant.hpp               +0   -1
include/ck_tile/core/numeric/math.hpp                            +12  -1
include/ck_tile/core/numeric/null_type.hpp                       +13  -0
include/ck_tile/core/numeric/vector_type.hpp                     +10  -1
include/ck_tile/core/tensor/buffer_view.hpp                      +41  -17
include/ck_tile/core/tensor/load_tile.hpp                        +13  -6
include/ck_tile/core/tensor/null_tile_window.hpp                 +2   -0
include/ck_tile/core/tensor/store_tile.hpp                       +1   -1
include/ck_tile/core/tensor/tensor_view.hpp                      +41  -13
include/ck_tile/core/tensor/tile_distribution.hpp                +1   -0
include/ck_tile/core/tensor/tile_elementwise.hpp                 +90  -11
include/ck_tile/core/tensor/tile_window.hpp                      +150 -10
include/ck_tile/core/tensor/update_tile.hpp                      +55  -0
include/ck_tile/core/utility/philox_rand.hpp                     +89  -0
include/ck_tile/host.hpp                                         +3   -0
include/ck_tile/host/check_err.hpp                               +15  -10
include/ck_tile/host/device_memory.hpp                           +41  -18
include/ck_tile/host/host_tensor.hpp                             +32  -4
include/ck_tile/host/kernel_launch.hpp                           +71  -131
include/ck_tile/host/reference/reference_batched_dropout.hpp    +33  -0
Too many changes to show: to preserve performance, only 317 of 317+ files are displayed; the diffs below cover the 20 files listed above.
include/ck_tile/core/numeric/integral_constant.hpp

@@ -56,7 +56,6 @@ CK_TILE_LEFT_UNARY_OP(+)
 CK_TILE_LEFT_UNARY_OP(-)
 CK_TILE_LEFT_UNARY_OP(~)
 CK_TILE_LEFT_UNARY_OP(!)
-CK_TILE_LEFT_UNARY_OP(*)
 CK_TILE_BINARY_OP(+)
 CK_TILE_BINARY_OP(-)
include/ck_tile/core/numeric/math.hpp

@@ -519,7 +519,7 @@ CK_TILE_DEVICE
 double sqrt(double x) { return __builtin_amdgcn_sqrt(x); };
 
 CK_TILE_DEVICE
-float exp(float x) { return __expf(x); };
+float exp(float x) { return __ocml_exp_f32(x); };
 
 CK_TILE_HOST
 float exp(float x) { return std::expf(x); }

@@ -536,4 +536,15 @@ float log(float x) { return __logf(x); };
 CK_TILE_HOST
 float log(float x) { return std::logf(x); };
+CK_TILE_DEVICE
+uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+{
+    // TODO: this is hacky, we use u16
+    return __builtin_amdgcn_sad_u16(x, y, acc);
+}
+CK_TILE_HOST
+uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+{
+    return (x > y ? (x - y) : (y - x)) + acc;
+}
 } // namespace ck_tile
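As a quick sanity check of the new sad() helpers, here is a minimal host-only sketch (plain C++, no ck_tile dependency; sad_ref is a hypothetical local mirror of the CK_TILE_HOST overload above, which computes |x - y| + acc):

#include <cassert>
#include <cstdint>

// Host mirror of the sum-of-absolute-difference fallback added above: |x - y| + acc.
static uint32_t sad_ref(uint32_t x, uint32_t y, uint32_t acc)
{
    return (x > y ? (x - y) : (y - x)) + acc;
}

int main()
{
    assert(sad_ref(7, 3, 10) == 14); // |7 - 3| + 10
    assert(sad_ref(3, 7, 0) == 4);   // symmetric in x and y
    return 0;
}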
include/ck_tile/core/numeric/null_type.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <stdint.h>

namespace ck_tile {

struct null_type
{
};

} // namespace ck_tile
include/ck_tile/core/numeric/vector_type.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once

@@ -144,6 +144,15 @@ using int8x16_t = int8_t __attribute((ext_vector_type(16)));
 using int8x32_t = int8_t __attribute((ext_vector_type(32)));
 using int8x64_t = int8_t __attribute((ext_vector_type(64)));
+
+// ui8
+// using uint8_t
+using uint8x2_t  = uint8_t __attribute((ext_vector_type(2)));
+using uint8x4_t  = uint8_t __attribute((ext_vector_type(4)));
+using uint8x8_t  = uint8_t __attribute((ext_vector_type(8)));
+using uint8x16_t = uint8_t __attribute((ext_vector_type(16)));
+using uint8x32_t = uint8_t __attribute((ext_vector_type(32)));
+using uint8x64_t = uint8_t __attribute((ext_vector_type(64)));
 #if CK_TILE_USE_CUSTOM_DATA_TYPE
 // f8
 // using fp8_t
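The new uint8xN aliases rely on clang's ext_vector_type attribute, which gives element-wise arithmetic and .x/.y/.z/.w accessors. A small stand-alone sketch (assumes clang or hipcc; uint8x4_local is a hypothetical local alias, not the one from the header):

#include <cstdint>
#include <cstdio>

using uint8x4_local = uint8_t __attribute__((ext_vector_type(4)));

int main()
{
    uint8x4_local a = {1, 2, 3, 4};
    uint8x4_local b = {10, 20, 30, 40};
    uint8x4_local c = a + b; // element-wise add on the whole vector
    std::printf("%d %d %d %d\n", c.x, c.y, c.z, c.w);
    return 0;
}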
include/ck_tile/core/tensor/buffer_view.hpp

@@ -6,6 +6,7 @@
 #include "ck_tile/core/config.hpp"
 #include "ck_tile/core/arch/arch.hpp"
 #include "ck_tile/core/arch/amd_buffer_addressing.hpp"
+#include "ck_tile/core/arch/generic_memory_space_atomic.hpp"
 #include "ck_tile/core/container/array.hpp"
 #include "ck_tile/core/numeric/integer.hpp"
 #include "ck_tile/core/numeric/integral_constant.hpp"

@@ -68,6 +69,8 @@ struct buffer_view<address_space_enum::generic,
     {
     }
 
+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::generic;

@@ -223,23 +226,34 @@ struct buffer_view<address_space_enum::global,
     T* p_data_ = nullptr;
     BufferSizeType buffer_size_;
+    int32x4_t cached_buf_res_;
     remove_cvref_t<T> invalid_element_value_ = T{0};
 
     CK_TILE_HOST_DEVICE constexpr buffer_view()
-        : p_data_{}, buffer_size_{}, invalid_element_value_{}
+        : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
     {
     }
 
     CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
-        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
+        : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
     {
     }
 
     CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                               BufferSizeType buffer_size,
                                               T invalid_element_value)
-        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
+        : p_data_{p_data},
+          buffer_size_{buffer_size},
+          cached_buf_res_{0},
+          invalid_element_value_{invalid_element_value}
     {
     }
 
+    // this is non constexpr intentially (will call some intrinsic internally)
+    // Must call for buffers that need *_raw load/store
+    CK_TILE_HOST_DEVICE void init_raw()
+    {
+        cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
+    }
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()

@@ -332,12 +346,15 @@ struct buffer_view<address_space_enum::global,
     // i is offset of T, not X. i should be aligned to X
     template <typename X,
               bool oob_conditional_check = true,
+              bool pre_nop               = false,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                   bool>::type = false>
     CK_TILE_DEVICE constexpr auto
-    get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
+    get_raw(remove_cvref_t<X>& dst,
+            index_t i,
+            bool is_valid_element,
+            bool_constant<pre_nop> = {}) const
     {
         constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -348,18 +365,21 @@ struct buffer_view<address_space_enum::global,
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
 
-        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
-            dst, p_data_, i, buffer_size_, is_valid_element);
+        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
+            dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
     }
 
     // i is offset of T, not X. i should be aligned to X
     template <typename X,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                   bool>::type = false>
     CK_TILE_DEVICE constexpr auto
-    async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
+    async_get_raw(remove_cvref_t<T>* smem,
+                  index_t i,
+                  bool /*is_valid_element*/,
+                  bool_constant<pre_nop> = {}) const
     {
         // X is vector of T
         constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -370,8 +390,8 @@ struct buffer_view<address_space_enum::global,
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
 
-        amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
-            smem, p_data_, i, buffer_size_);
+        amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
+            smem, cached_buf_res_, i, bool_constant<pre_nop>{});
     }
 
     // i is offset of T, not X. i should be aligned to X

@@ -507,10 +527,10 @@ struct buffer_view<address_space_enum::global,
         bool constexpr use_amd_buffer_addressing = false;
 #endif
 
+        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
         if constexpr(use_amd_buffer_addressing)
         {
-            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
-
             amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
                 x, p_data_, i, is_valid_element, buffer_size_);
         }

@@ -518,7 +538,7 @@ struct buffer_view<address_space_enum::global,
         {
             if(is_valid_element)
             {
-                atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
+                atomic_add_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
             }
         }
     }

@@ -547,16 +567,16 @@ struct buffer_view<address_space_enum::global,
         bool constexpr use_amd_buffer_addressing = false;
 #endif
 
+        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
         if constexpr(use_amd_buffer_addressing)
         {
-            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
-
             amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
                 x, p_data_, i, is_valid_element, buffer_size_);
         }
         else if(is_valid_element)
         {
-            atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
+            atomic_max_g<remove_cvref_t<T>, t_per_x>(&p_data_[i], x);
         }
     }

@@ -626,6 +646,8 @@ struct buffer_view<address_space_enum::lds,
     {
     }
 
+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::lds;

@@ -908,6 +930,8 @@ struct buffer_view<address_space_enum::vgpr,
     {
     }
 
+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::vgpr;
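The point of the cached_buf_res_/init_raw() pair is that the 128-bit AMD buffer resource descriptor is built once instead of being recomputed on every raw access. A conceptual host-only sketch of what such a descriptor holds (the layout and config word below are placeholders, not the exact values produced by make_wave_buffer_resource):

#include <cstdint>
#include <cstring>

struct buffer_resource // stand-in for the int32x4_t cached_buf_res_ member
{
    uint32_t dword[4];
};

static buffer_resource make_buffer_resource_sketch(const void* base, uint32_t size_bytes)
{
    buffer_resource res{};
    std::memcpy(&res.dword[0], &base, sizeof(void*)); // dwords 0-1: base address
    res.dword[2] = size_bytes;                        // dword 2: addressable range in bytes
    res.dword[3] = 0x00020000u;                       // dword 3: placeholder config word
    return res;
}

int main()
{
    float data[64];
    buffer_resource r = make_buffer_resource_sketch(data, sizeof(data));
    return r.dword[2] == sizeof(data) ? 0 : 1;
}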
include/ck_tile/core/tensor/load_tile.hpp

@@ -36,30 +36,37 @@ template <typename T,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
-          bool oob_conditional_check = true>
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
 CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                   const tile_window_with_static_distribution<BottomTensorView_,
                                                                               WindowLengths_,
                                                                               TileDistribution_,
                                                                               NumCoord>& tile_window,
-                                  bool_constant<oob_conditional_check> = {})
+                                  bool_constant<oob_conditional_check> = {},
+                                  bool_constant<pre_nop>               = {})
 {
-    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
+    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }
 
 template <typename LdsTileWindow_,
           typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
-          index_t NumCoord>
+          index_t NumCoord,
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
 CK_TILE_DEVICE auto
 async_load_tile_raw(LdsTileWindow_&& lds_tile,
                     const tile_window_with_static_distribution<BottomTensorView_,
                                                                 WindowLengths_,
                                                                 TileDistribution_,
-                                                                NumCoord>& tile_window)
+                                                                NumCoord>& tile_window,
+                    bool_constant<oob_conditional_check> = {},
+                    bool_constant<pre_nop>               = {})
 {
-    return tile_window.async_load(lds_tile);
+    return tile_window.async_load_raw(
+        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }
 
 CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
include/ck_tile/core/tensor/null_tile_window.hpp

@@ -35,6 +35,8 @@ struct null_tile_window
     CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
 
+    CK_TILE_DEVICE void init_raw() {}
+
     WindowLengths window_lengths_;
 };
include/ck_tile/core/tensor/store_tile.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
include/ck_tile/core/tensor/tensor_view.hpp

@@ -16,7 +16,9 @@
 namespace ck_tile {
 
-template <typename BufferView_, typename TensorDesc_>
+template <typename BufferView_,
+          typename TensorDesc_,
+          memory_operation_enum DstInMemOp_ = memory_operation_enum::set>
 struct tensor_view
 {
     using buffer_view = remove_reference_t<BufferView_>;

@@ -24,6 +26,7 @@ struct tensor_view
     using TensorDesc  = remove_cvref_t<TensorDesc_>;
     using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
     using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));
+    static constexpr auto DstInMemOp = DstInMemOp_;
 
     CK_TILE_HOST_DEVICE constexpr tensor_view() = default;

@@ -33,6 +36,8 @@ struct tensor_view
     {
     }
 
+    CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
+
     CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
 
     CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()

@@ -82,30 +87,34 @@ struct tensor_view
     // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
     template <typename X,
               bool oob_conditional_check = true,
+              bool pre_nop               = false,
               typename std::enable_if<
                   std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                  typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                   bool>::type = false>
     CK_TILE_HOST_DEVICE void
     get_vectorized_elements_raw(remove_cvref_t<X>& dst,
                                 const TensorCoord& coord,
-                                bool_constant<oob_conditional_check> = {}) const
+                                bool_constant<oob_conditional_check> = {},
+                                bool_constant<pre_nop>               = {}) const
     {
-        return buf_.template get_raw<X, oob_conditional_check>(
+        return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
             dst,
             coord.get_offset(),
-            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
+            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+            bool_constant<pre_nop>{});
     }
 
     template <typename X,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                  typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                   bool>::type = false>
-    CK_TILE_HOST_DEVICE constexpr void async_get_vectorized_elements(
-        remove_cvref_t<DataType>* smem, const TensorCoord& coord) const
+    CK_TILE_HOST_DEVICE constexpr void
+    async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
+                                      const TensorCoord& coord,
+                                      bool_constant<pre_nop> = {}) const
     {
-        return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
+        return buf_.template async_get_raw<X>(
+            smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
     }
 
     // X is vector of DataType.

@@ -140,6 +149,23 @@ struct tensor_view
             x);
     }
 
+    // X is vector of DataType.
+    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
+    template <typename X,
+              bool oob_conditional_check = true,
+              typename std::enable_if<
+                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
+                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
+                  bool>::type = false>
+    CK_TILE_HOST_DEVICE constexpr void update_vectorized_elements(
+        const TensorCoord& coord, const X& x, bool_constant<oob_conditional_check> = {})
+    {
+        buf_.template update<DstInMemOp, X, oob_conditional_check>(
+            coord.get_offset(),
+            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+            x);
+    }
+
     CK_TILE_HOST_DEVICE void print() const
     {
         printf("tensor_view{");

@@ -178,6 +204,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
 }
 
 template <address_space_enum BufferAddressSpace = address_space_enum::generic,
+          memory_operation_enum DstInMemOp      = memory_operation_enum::set,
           typename DataType,
           typename... Lengths,
           typename... Strides,

@@ -198,7 +225,7 @@ make_naive_tensor_view(DataType* p,
     auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());
 
-    return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
+    return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
 }
 
 template <address_space_enum BufferAddressSpace = address_space_enum::generic,

@@ -232,8 +259,9 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& ol
         NewLowerDimensionOldVisibleIdss{},
         NewUpperDimensionNewVisibleIdss{});
 
-    return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{
-        old_tensor_view.buf_, new_desc};
+    return tensor_view<typename OldTensorView::buffer_view,
+                       remove_cvref_t<decltype(new_desc)>,
+                       remove_cvref_t<OldTensorView>::DstInMemOp>{old_tensor_view.buf_, new_desc};
 }
 
 template <typename TensorView,
include/ck_tile/core/tensor/tile_distribution.hpp

@@ -9,6 +9,7 @@
 #include "ck_tile/core/container/sequence.hpp"
 #include "ck_tile/core/container/tuple.hpp"
 #include "ck_tile/core/container/container_helper.hpp"
+#include "ck_tile/core/container/meta_data_buffer.hpp"
 #include "ck_tile/core/tensor/tensor_adaptor.hpp"
 #include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
 #include "ck_tile/core/utility/functional.hpp"
include/ck_tile/core/tensor/tile_elementwise.hpp

@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
 // TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
 // sub-dword tensor...
-template <typename DstrTensors, index_t v>
-CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
+template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
+CK_TILE_DEVICE void
+set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
 {
-    constexpr index_t tensor_bytes =
-        DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
-    if constexpr(v == 0 && tensor_bytes % 4 == 0)
+    using elem_type             = typename DstrTensors::DataType;
+    constexpr index_t elem_size = sizeof(elem_type);
+
+    constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
+
+    // # bytes per write = 4
+    if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
     {
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+        auto& buffer = dstr_tensor.get_thread_buffer();
+
+        static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
+            if constexpr(elem_size == 1)
+            {
+                // # elements per write = 4
+                constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
+
+                buffer[i_write * 4 + 0] = values.x;
+                buffer[i_write * 4 + 1] = values.y;
+                buffer[i_write * 4 + 2] = values.z;
+                buffer[i_write * 4 + 3] = values.w;
+            }
+            else if constexpr(elem_size == 2)
+            {
+                // # elements per write = 2
+                constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
+
+                buffer[i_write * 2 + 0] = values.x;
+                buffer[i_write * 2 + 1] = values.y;
+            }
+            else if constexpr(elem_size == 4)
+            {
+                // # elements per write = 1
+                constexpr elem_type value = 0;
+
+                buffer[i_write] = value;
+            }
+            else
+            {
+                static_assert(false, "type not supported");
+            }
+        });
+#else
         using dvec_t = array<index_t, tensor_bytes / 4>;
         auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
         for(auto i = 0; i < tensor.size(); i++)
             tensor.get(i) = v;
+#endif
     }
     else
     {
-        tile_elementwise_inout(
-            [](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
-            dstr_tensor);
+        tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
+                               dstr_tensor);
     }
 }

@@ -110,9 +150,9 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
 namespace impl {
 
 // TODO: this is ugly
 template <typename OutDataType, typename InTensor>
-CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
+CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
 {
-#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+#if defined(__gfx94__)
     // This API is designed to use the _pk_ serious of function
     constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

@@ -156,6 +196,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
 #endif
 }
 
+template <typename OutDataType, typename InTensor>
+CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
+{
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
+    // This API is designed to use the _pk_ serious of function
+    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
+
+    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
+    static_assert(thread_buffer_size % 2 == 0);
+    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
+
+    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
+
+    // TODO: this is rtz cvt, need be very careful
+    for(index_t i = 0; i < thread_buffer_size_pk; i++)
+    {
+        auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
+                                            in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
+
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
+    }
+
+    return out_dstr_tensor;
+#else
+    // fallback
+    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
+                               in_dstr_tensors);
+#endif
+}
+
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
 // this function assume either src or dst (or both) date type is under 1 dword
 // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)

@@ -229,8 +300,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
                             float> &&
                  (SrcTensor::get_thread_buffer_size() % 4 == 0))
     {
-        return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
+        return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
     }
+#if CK_TILE_USE_PK_FP16_TILE_CAST
+    else if constexpr(std::is_same_v<DstType, fp16_t> &&
+                      std::is_same_v<typename SrcTensor::DataType, float> &&
+                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
+    {
+        return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
+    }
+#endif
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
     else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
     {
include/ck_tile/core/tensor/tile_window.hpp

@@ -355,9 +355,10 @@ struct tile_window_with_static_distribution
         return dst_tensor;
     }
 
-    template <typename DstTile, bool oob_conditional_check = true>
+    template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
     CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
-                                 bool_constant<oob_conditional_check> = {}) const
+                                 bool_constant<oob_conditional_check> = {},
+                                 bool_constant<pre_nop>               = {}) const
     {
         using Traits = load_store_traits;

@@ -384,7 +385,13 @@ struct tile_window_with_static_distribution
             auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 
             static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto pre_nop_ = [&]() {
+                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                        return bool_constant<true>{};
+                    else
+                        return bool_constant<false>{};
+                }();
 
                 // data index [y0, y1, ...]
                 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

@@ -395,7 +402,8 @@ struct tile_window_with_static_distribution
                 get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                     dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
                     bottom_tensor_thread_coord,
-                    bool_constant<oob_conditional_check>{});
+                    bool_constant<oob_conditional_check>{},
+                    pre_nop_);
 
                 // move thread coordinate
                 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -410,12 +418,17 @@ struct tile_window_with_static_distribution
                 }
             });
         });
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+        asm volatile("; this inline asm is workaround to prevent compiler from using too much "
+                     "scratch memory" ::);
+#endif
     }
 
     // TODO: currently async load only implemented in inline asm
-    template <typename LdsTileWindow_, bool oob_conditional_check = true>
-    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
-                                   bool_constant<oob_conditional_check> = {}) const
+    template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
+    CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
+                                       bool_constant<oob_conditional_check> = {},
+                                       bool_constant<pre_nop>               = {}) const
     {
         using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
         // using LdsTensorView = typename LdsTileWindow::BottomTensorView;

@@ -460,11 +473,17 @@ struct tile_window_with_static_distribution
             auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 
             static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto pre_nop_ = [&]() {
+                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                        return bool_constant<true>{};
+                    else
+                        return bool_constant<false>{};
+                }();
 
                 // read from bottom tensor
-                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
-                    smem, bottom_tensor_thread_coord);
+                get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
+                    smem, bottom_tensor_thread_coord, pre_nop_);
 
                 // move thread coordinate
                 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -605,6 +624,66 @@ struct tile_window_with_static_distribution
         });
     }
 
+    template <bool oob_conditional_check = true>
+    CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                               bool_constant<oob_conditional_check> = {}) const
+    {
+        using Traits = load_store_traits;
+
+        using vector_t = typename Traits::vector_t;
+        using SFC_Ys   = typename Traits::SFC_Ys;
+
+        constexpr auto tile_dstr = TileDstr{};
+
+        // loop over thread tensor space [y0, y1, ...]
+        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+            /// TODO: use structure binding (to be captured later) if compiled in C++20
+            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
+            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];
+
+            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
+                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+
+                // data index [y0, y1, ...]
+                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
+
+                // read from distributed tensor
+                vector_t vec_value;
+
+                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                    constexpr auto idx_ys = generate_array(
+                        [&](auto jj) {
+                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
+                                                            : idx_ys_start[jj];
+                        },
+                        number<NDimY>{});
+
+                    constexpr index_t d =
+                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+
+                    vec_value.template get_as<DataType>()(j) =
+                        dstr_tensor.get_thread_buffer().template at<d>();
+                });
+
+                // write into bottom tensor
+                get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
+                    bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});
+
+                // move thread coordinate
+                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
+                {
+                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+
+                    constexpr auto idx_diff_ps_ys =
+                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+
+                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+                }
+            });
+        });
+    }
+
     // move thread's botom tensor coordiante
     // [x0', x1', ... ] ==> [offset]
     // also move window-origin

@@ -619,6 +698,67 @@ struct tile_window_with_static_distribution
         });
     }
 
+    CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
+    {
+        window_origin_ = new_window_origin;
+
+#if 0 // debug
+        // TODO: this use more register for FA, but less register for GEMM
+        // need investigation
+        // only support warp-tile and block-tile
+        static_assert(NDimP == 1 or NDimP == 2, "wrong!");
+
+        WindowAdaptorCoord window_adaptor_thread_coord_tmp;
+
+        if constexpr(NDimP == 1)
+        {
+            window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
+                tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
+        }
+        else if constexpr(NDimP == 2)
+        {
+            window_adaptor_thread_coord_tmp =
+                make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
+                                               AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
+        }
+#else
+        // TODO: this use less register for FA, but more register for GEMM
+        // need investigation
+        const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
+            tile_dstr_.get_ps_ys_to_xs_adaptor(),
+            container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
+#endif
+
+        BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
+            window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
+
+        const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
+            bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
+
+        // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
+        // future load/store() calls (might allocate more registers)
+        using Traits = load_store_traits;
+        using SFC_Ys = typename Traits::SFC_Ys;
+
+        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+            auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
+            auto bottom_tensor_thread_coord  = bottom_tensor_thread_coord_tmp;
+
+            constexpr auto idx_diff_ys =
+                SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
+
+            constexpr auto idx_diff_ps_ys =
+                container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+
+            move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+
+            pre_computed_coords_(iCoord) =
+                make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
+        });
+    }
+
+    CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
+
     // this is the bottom tensor view
     // [x0', x1', ...] ==> [offset]
     BottomTensorView bottom_tensor_view_;
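The pre_nop_ immediately-invoked constexpr lambda above selects bool_constant<true> only for the very first access of the first coordinate bundle. A stand-alone illustration of the same compile-time selection (C++17; std::true_type/false_type stand in for ck_tile::bool_constant):

#include <type_traits>

template <bool pre_nop, int iCoord, int iCoordAccess>
constexpr auto make_pre_nop()
{
    // mirror of the pre_nop_ lambda: true only for the first access
    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
        return std::true_type{};
    else
        return std::false_type{};
}

static_assert(decltype(make_pre_nop<true, 0, 0>())::value);
static_assert(!decltype(make_pre_nop<true, 0, 1>())::value);
static_assert(!decltype(make_pre_nop<false, 0, 0>())::value);

int main() { return 0; }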
include/ck_tile/core/tensor/update_tile.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using TileDstr = remove_cvref_t<TileDistribution_>;

    static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");

    constexpr auto tile_dstr = TileDstr{};

    auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                        tile_window_tmp.get_window_lengths(),
                                        tile_window_tmp.get_window_origin(),
                                        tile_dstr);

    tile_window.update(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                 WindowLengths_,
                                                 TileDistribution_,
                                                 NumCoord>& tile_window,
            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.update(dstr_tensor);
}

} // namespace ck_tile
include/ck_tile/core/utility/philox_rand.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"

namespace ck_tile {

// Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
class philox
{
    public:
    CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
        : seed(reinterpret_cast<const uint2&>(seed_))
    {
        ull2* tmp = reinterpret_cast<ull2*>(&counter);
        tmp->x    = offset_;
    }

    CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
    {
        uint4 counter_ = counter;
        ull2* tmp      = reinterpret_cast<ull2*>(&counter_);
        tmp->y         = subsequence;

        uint2 key_ = seed;

        // 7-round philox
#pragma unroll
        for(int i = 0; i < 6; i++)
        {
            counter_ = philox_single_round(counter_, key_);
            key_.x += kPhilox10A;
            key_.y += kPhilox10B;
        }
        uint4 output = philox_single_round(counter_, key_);
        return output;
    }

    CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t* out,
                                             const unsigned long long subsequence) const
    {
        uint4 tmp_ph;
        tmp_ph = get_philox_4x32(subsequence);

        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);

        out_tmp[0] = tmp_ph.x;
        out_tmp[1] = tmp_ph.y;
        out_tmp[2] = tmp_ph.z;
        out_tmp[3] = tmp_ph.w;
    }

    private:
    struct ull2
    {
        uint64_t x;
        uint64_t y;
    };

    uint4 counter;
    const uint2 seed;

    CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
    {
        uint2* res;
        unsigned long long tmp;
        tmp = static_cast<unsigned long long>(a) * b;
        res = reinterpret_cast<uint2*>(&tmp);
        return *res;
    }

    CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
    {
        uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
        uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);

        uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
        return ret;
    }

    static const unsigned long kPhilox10A = 0x9E3779B9;
    static const unsigned long kPhilox10B = 0xBB67AE85;
    static const unsigned long kPhiloxSA  = 0xD2511F53;
    static const unsigned long kPhiloxSB  = 0xCD9E8D57;
};

} // namespace ck_tile
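For readers unfamiliar with Philox: the class above runs seven rounds of the Philox-4x32 mixing function, bumping the key between rounds. A host-only sketch of that round structure (Vec2/Vec4 are local stand-ins for HIP's uint2/uint4; the constants are the standard Philox-4x32 constants used above, and the seed/counter values here are arbitrary):

#include <cstdint>
#include <cstdio>

struct Vec2 { uint32_t x, y; };
struct Vec4 { uint32_t x, y, z, w; };

// 32x32 -> 64-bit multiply, split into {lo, hi}
static Vec2 mulhilo32(uint32_t a, uint32_t b)
{
    uint64_t p = static_cast<uint64_t>(a) * b;
    return {static_cast<uint32_t>(p), static_cast<uint32_t>(p >> 32)};
}

static Vec4 philox_single_round(Vec4 ctr, Vec2 key)
{
    Vec2 r0 = mulhilo32(0xD2511F53u, ctr.x);
    Vec2 r1 = mulhilo32(0xCD9E8D57u, ctr.z);
    return {r1.y ^ ctr.y ^ key.x, r1.x, r0.y ^ ctr.w ^ key.y, r0.x};
}

int main()
{
    Vec4 ctr{0, 0, 0, 0};               // counter: offset/subsequence would go here
    Vec2 key{0x12345678u, 0x9ABCDEF0u}; // would be derived from the 64-bit seed
    for(int i = 0; i < 6; ++i)
    {
        ctr = philox_single_round(ctr, key);
        key.x += 0x9E3779B9u; // kPhilox10A
        key.y += 0xBB67AE85u; // kPhilox10B
    }
    Vec4 out = philox_single_round(ctr, key); // 7th round produces the output
    std::printf("%08x %08x %08x %08x\n", out.x, out.y, out.z, out.w);
    return 0;
}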
include/ck_tile/host.hpp

@@ -11,12 +11,15 @@
 #include "ck_tile/host/host_tensor.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
 #include "ck_tile/host/ranges.hpp"
+#include "ck_tile/host/reference/reference_batched_dropout.hpp"
 #include "ck_tile/host/reference/reference_batched_elementwise.hpp"
 #include "ck_tile/host/reference/reference_batched_gemm.hpp"
 #include "ck_tile/host/reference/reference_batched_masking.hpp"
 #include "ck_tile/host/reference/reference_batched_softmax.hpp"
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
+#include "ck_tile/host/reference/reference_layernorm2d.hpp"
 #include "ck_tile/host/reference/reference_reduce.hpp"
 #include "ck_tile/host/reference/reference_softmax.hpp"
 #include "ck_tile/host/stream_config.hpp"
+#include "ck_tile/host/timer.hpp"
include/ck_tile/host/check_err.hpp

The same change is applied to the is_infinity_error lambda in five check_err overloads (hunks at lines 56, 114, 173, 285, and 361):

@@ -56,8 +56,9 @@ check_err(const Range& out,
     }
 
     const auto is_infinity_error = [=](auto o, auto r) {
         const bool either_not_finite = !std::isfinite(o) || !std::isfinite(r);
-        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);
+        const bool both_infinite_and_same =
+            std::isinf(o) && std::isinf(r) &&
+            (bit_cast<uint64_t>(o) == bit_cast<uint64_t>(r));
 
         return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
     };
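The only functional change here is that the "both infinite and same" check now compares bit patterns rather than using a floating-point equality. A small host-only illustration (uses C++20 std::bit_cast in place of ck_tile's bit_cast, and double so the value width matches uint64_t):

#include <bit>
#include <cassert>
#include <cstdint>
#include <limits>

int main()
{
    const double pinf = std::numeric_limits<double>::infinity();
    const double ninf = -pinf;

    // same sign of infinity -> identical bit patterns
    assert(std::bit_cast<std::uint64_t>(pinf) == std::bit_cast<std::uint64_t>(pinf));
    // opposite signs -> different bit patterns, so not "both infinite and same"
    assert(std::bit_cast<std::uint64_t>(pinf) != std::bit_cast<std::uint64_t>(ninf));
    return 0;
}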
include/ck_tile/host/device_memory.hpp

@@ -27,7 +27,14 @@ struct DeviceMem
     DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
     DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
     {
-        HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        if(mMemSize != 0)
+        {
+            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        }
+        else
+        {
+            mpDeviceBuf = nullptr;
+        }
     }
     void Realloc(std::size_t mem_size)
     {

@@ -36,7 +43,14 @@ struct DeviceMem
             HIP_CHECK_ERROR(hipFree(mpDeviceBuf));
         }
         mMemSize = mem_size;
-        HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        if(mMemSize != 0)
+        {
+            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+        }
+        else
+        {
+            mpDeviceBuf = nullptr;
+        }
     }
     void* GetDeviceBuffer() const { return mpDeviceBuf; }
     std::size_t GetBufferSize() const { return mMemSize; }

@@ -47,15 +61,18 @@ struct DeviceMem
             HIP_CHECK_ERROR(
                 hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
         }
-        else
-        {
-            throw std::runtime_error("ToDevice with an empty pointer");
-        }
+        // else
+        // {
+        //     throw std::runtime_error("ToDevice with an empty pointer");
+        // }
     }
     void ToDevice(const void* p, const std::size_t cpySize) const
     {
-        HIP_CHECK_ERROR(
-            hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
+        if(mpDeviceBuf)
+        {
+            HIP_CHECK_ERROR(
+                hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
+        }
     }
     void FromDevice(void* p) const
     {

@@ -63,14 +80,17 @@ struct DeviceMem
         {
             HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
         }
-        else
-        {
-            throw std::runtime_error("FromDevice with an empty pointer");
-        }
+        // else
+        // {
+        //     throw std::runtime_error("FromDevice with an empty pointer");
+        // }
     }
     void FromDevice(void* p, const std::size_t cpySize) const
     {
-        HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
+        if(mpDeviceBuf)
+        {
+            HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
+        }
     }
     void SetZero() const
     {

@@ -82,13 +102,16 @@ struct DeviceMem
     template <typename T>
     void SetValue(T x) const
     {
-        if(mMemSize % sizeof(T) != 0)
+        if(mpDeviceBuf)
         {
-            throw std::runtime_error("wrong! not entire DeviceMem will be set");
+            if(mMemSize % sizeof(T) != 0)
+            {
+                throw std::runtime_error("wrong! not entire DeviceMem will be set");
+            }
+            // TODO: call a gpu kernel to set the value (?)
+            set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
         }
-        // TODO: call a gpu kernel to set the value (?)
-        set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
     }
     ~DeviceMem()
     {
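With these guards, a zero-sized DeviceMem never calls hipMalloc, and the copy helpers skip work instead of throwing when the buffer is empty. A hedged usage sketch (assumes the ck_tile host headers and a HIP toolchain; the exact namespacing of DeviceMem follows the header above):

#include "ck_tile/host/device_memory.hpp"
#include <vector>

int main()
{
    std::vector<float> host(16, 1.0f);

    ck_tile::DeviceMem empty(0);  // no hipMalloc, device pointer stays nullptr
    empty.ToDevice(host.data());  // now a no-op instead of an error

    ck_tile::DeviceMem buf(host.size() * sizeof(float));
    buf.ToDevice(host.data());    // copies 16 floats to the device
    buf.FromDevice(host.data());  // and back
    return 0;
}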
include/ck_tile/host/host_tensor.hpp
View file @
e6bb1dd7
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -156,7 +156,7 @@ struct HostTensorDescriptor
    }

    const std::vector<std::size_t>& get_lengths() const { return mLens; }
-    const std::vector<std::size_t>& GetStrides() const { return mStrides; }
+    const std::vector<std::size_t>& get_strides() const { return mStrides; }

    template <typename... Is>
    std::size_t GetOffsetFromMultiIndex(Is... is) const
...
@@ -188,7 +188,7 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
    for(std::size_t i = 0; i < a.get_num_of_dimension(); i++)
    {
        new_lengths[i] = a.get_lengths()[new2old[i]];
-        new_strides[i] = a.GetStrides()[new2old[i]];
+        new_strides[i] = a.get_strides()[new2old[i]];
    }

    return HostTensorDescriptor(new_lengths, new_strides);
...
@@ -327,7 +327,7 @@ struct HostTensor
    decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
-    decltype(auto) GetStrides() const { return mDesc.GetStrides(); }
+    decltype(auto) get_strides() const { return mDesc.get_strides(); }

    std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
...
@@ -481,6 +481,34 @@ struct HostTensor
        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
    }

+    HostTensor<T> transpose(std::vector<size_t> axes = {}) const
+    {
+        if(axes.empty())
+        {
+            axes.resize(this->get_num_of_dimension());
+            std::iota(axes.rbegin(), axes.rend(), 0);
+        }
+        if(axes.size() != mDesc.get_num_of_dimension())
+        {
+            throw std::runtime_error(
+                "HostTensor::transpose(): size of axes must match tensor dimension");
+        }
+        std::vector<size_t> tlengths, tstrides;
+        for(const auto& axis : axes)
+        {
+            tlengths.push_back(get_lengths()[axis]);
+            tstrides.push_back(get_strides()[axis]);
+        }
+        HostTensor<T> ret(*this);
+        ret.mDesc = HostTensorDescriptor(tlengths, tstrides);
+        return ret;
+    }
+
+    HostTensor<T> transpose(std::vector<size_t> axes = {})
+    {
+        return const_cast<HostTensor<T> const*>(this)->transpose(axes);
+    }
+
    typename Data::iterator begin() { return mData.begin(); }
    typename Data::iterator end() { return mData.end(); }
...
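A small host-only sketch of the new HostTensor::transpose() together with the renamed get_strides() accessor. The shape is arbitrary and the lengths-only constructor with packed (row-major) strides is assumed; transpose() copies the tensor and permutes its descriptor, no data is rearranged.

#include "ck_tile/host.hpp"
#include <cassert>

int main()
{
    // hypothetical 2x3x4 tensor with packed strides {12, 4, 1}
    ck_tile::HostTensor<float> a({2, 3, 4});

    auto rev  = a.transpose();          // empty axes -> reverse all dims: lengths {4, 3, 2}
    auto perm = a.transpose({0, 2, 1}); // explicit permutation: lengths {2, 4, 3}

    assert(rev.get_lengths()[0] == 4);
    assert(perm.get_strides()[1] == 1); // the old innermost stride follows its axis
    return 0;
}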
include/ck_tile/host/kernel_launch.hpp
View file @
e6bb1dd7
...
@@ -6,6 +6,7 @@
 #include "ck_tile/core/config.hpp"
 #include "ck_tile/host/stream_config.hpp"
 #include "ck_tile/host/hip_check_error.hpp"
+#include "ck_tile/host/timer.hpp"
 #include <hip/hip_runtime.h>
 #include <cstddef>
...
@@ -14,153 +15,92 @@ template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename...
 #if CK_TILE_USE_LAUNCH_BOUNDS
 __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
 #endif
-__global__ void kentry(Kernel f, Args... args)
+__global__ void kentry(Args... args)
 {
-    f(args...);
+    Kernel{}(args...);
 }

-template <typename... Args, typename F>
-CK_TILE_HOST float launch_and_time_kernel(const stream_config& s,
-                                          F kernel,
-                                          dim3 grid_dim,
-                                          dim3 block_dim,
-                                          std::size_t lds_byte,
-                                          Args... args)
-{
-#if CK_TILE_TIME_KERNEL
-    if(s.time_kernel_)
-    {
-        // warm up
-        for(int i = 0; i < s.cold_niters_; ++i)
-        {
-            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-            hip_check_error(hipGetLastError());
-        }
-
-        const int nrepeat = s.nrepeat_;
-        hipEvent_t start, stop;
-        HIP_CHECK_ERROR(hipEventCreate(&start));
-        HIP_CHECK_ERROR(hipEventCreate(&stop));
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
-        HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
-
-        for(int i = 0; i < nrepeat; ++i)
-        {
-            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-            hip_check_error(hipGetLastError());
-        }
-
-        HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
-        HIP_CHECK_ERROR(hipEventSynchronize(stop));
-
-        float total_time = 0;
-        HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));
-
-        return total_time / nrepeat;
-    }
-    else
-    {
-        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
-        return 0;
-    }
-#else
-    kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-    hip_check_error(hipGetLastError());
-    return 0;
-#endif
-}
+// return a anonymous functor(lambda) to be called later
+// the KernelImpl should be a class without non-static data member, or let's say
+// can be instantiate with "KernelImpl{}"
+//
+// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
+//
+template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
+          int MinBlockPerCu     = CK_TILE_MIN_BLOCK_PER_CU,
+          typename KernelImpl,
+          typename... Args>
+CK_TILE_HOST auto make_kernel(KernelImpl /*f*/,
+                              dim3 grid_dim,
+                              dim3 block_dim,
+                              std::size_t lds_byte,
+                              Args... args)
+{
+    const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
+
+    return [=](const stream_config& s) {
+        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
+    };
+}

-template <typename... Args, typename F, typename PreProcessFunc>
-CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s,
-                                                          PreProcessFunc preprocess,
-                                                          F kernel,
-                                                          dim3 grid_dim,
-                                                          dim3 block_dim,
-                                                          std::size_t lds_byte,
-                                                          Args... args)
-{
-#if CK_TILE_TIME_KERNEL
-    if(s.time_kernel_)
-    {
-#if CK_TILE_DEBUG_LOG
-        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}\n",
-               __func__,
-               grid_dim.x,
-               grid_dim.y,
-               grid_dim.z,
-               block_dim.x,
-               block_dim.y,
-               block_dim.z);
-        printf("Warm up 1 time\n");
-#endif
-        // warm up
-        preprocess();
-        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
-
-        const int nrepeat = 10;
-#if CK_TILE_DEBUG_LOG
-        printf("Start running %d times...\n", nrepeat);
-#endif
-        hipEvent_t start, stop;
-        HIP_CHECK_ERROR(hipEventCreate(&start));
-        HIP_CHECK_ERROR(hipEventCreate(&stop));
-        HIP_CHECK_ERROR(hipDeviceSynchronize());
-        HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));
-
-        for(int i = 0; i < nrepeat; ++i)
-        {
-            preprocess();
-            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-            hip_check_error(hipGetLastError());
-        }
-
-        HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
-        HIP_CHECK_ERROR(hipEventSynchronize(stop));
-
-        float total_time = 0;
-        HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));
-
-        return total_time / nrepeat;
-    }
-    else
-    {
-        preprocess();
-        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-        hip_check_error(hipGetLastError());
-        return 0;
-    }
-#else
-    kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
-    hip_check_error(hipGetLastError());
-    return 0;
-#endif
-}
+// clang-format off
+/*
+ * launch_kernel()
+ *
+ * this is the function to launch arbitrary number of kernels with optional timer(selected by stream_config)
+ * the callables should have signature as "operator()(const stream_config& s){ ... }" to call
+ *
+ * the simplest way is pass in a lambda function, with "[=](const stream_config& s){ call_your_kernel_here() }"
+ * as signature, for the callable (pay attention to the capture list)
+ *
+ * e.g.
+ * ck_tile::launch_kernel(s,
+ *       [=](const stream_config& s){ hipMemset(ptr, 0, size) },
+ *       [=](const stream_config& s){ some_kernel<<<grids, blocks>>>(arg); }
+ * );
+ *
+ * if you use ck_tile kernel, or similiar to this style (structure with "static __device__ operator()(...){}")
+ * you can pass your kernel to ck_tile::make_kernel(), which will create a anonymous functor for you,
+ * then pass it to ck_tile::launch_kernel()
+ *
+ * e.g.
+ * ck_tile::launch_kernel(s,
+ *      ck_tile::make_kernel<T0, B0>(kernel_0{}, grids0, blocks0, 0, kargs0),
+ *      ck_tile::make_kernel<T0, B1>(kernel_1{}, grids1, blocks1, 0, kargs1),
+ *      ...);
+ **/
+// clang-format on
+template <typename... Callables>
+CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
+{
+    // clang-format off
+    if(!s.time_kernel_)
+    {
+        (callables(s),...);
+        hip_check_error(hipGetLastError());
+        return 0;
+    }
+    if(s.is_gpu_timer_)
+    {
+        gpu_timer timer{};
+
+        // warmup
+        for(int i = 0; i < s.cold_niters_; i++)
+        {
+            (callables(s),...);
+        }
+        hip_check_error(hipGetLastError());
+
+        timer.start(s.stream_id_);
+        for(int i = 0; i < s.nrepeat_; i++)
+        {
+            (callables(s),...);
+        }
+        hip_check_error(hipGetLastError());
+        timer.stop(s.stream_id_);
+
+        return timer.duration() / s.nrepeat_;
+    }
+    else
+    {
+        cpu_timer timer{};
+
+        // warmup
+        for(int i = 0; i < s.cold_niters_; i++)
+        {
+            (callables(s),...);
+        }
+        hip_check_error(hipGetLastError());
+
+        timer.start(s.stream_id_);
+        for(int i = 0; i < s.nrepeat_; i++)
+        {
+            (callables(s),...);
+        }
+        hip_check_error(hipGetLastError());
+        timer.stop(s.stream_id_);
+
+        return timer.duration() / s.nrepeat_;
+    }
+    // clang-format on
+}

-template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
-          int MinBlockPerCu     = CK_TILE_MIN_BLOCK_PER_CU,
-          typename KernelImpl,
-          typename... Args>
-CK_TILE_HOST float launch_kernel(const stream_config& s,
-                                 KernelImpl kernel_impl,
-                                 dim3 grid_dim,
-                                 dim3 block_dim,
-                                 std::size_t dynamic_smem_byte,
-                                 Args... args)
-{
-    const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
-
-    return launch_and_time_kernel(
-        s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...);
-}
 } // namespace ck_tile
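To make the new flow concrete, here is a minimal sketch of a hand-written kernel functor driven through make_kernel()/launch_kernel(). The functor, problem size and launch geometry are invented for the example, and stream_config is default-constructed (so no timing) to avoid assuming its member layout; CK_TILE_DEVICE is used as the device-code qualifier as elsewhere in ck_tile.

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"

// hypothetical kernel: no non-static data members, entered through operator()
struct fill_kernel
{
    CK_TILE_DEVICE void operator()(float* p, int n, float v) const
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
            p[i] = v;
    }
};

int main()
{
    constexpr int n = 4096;
    ck_tile::DeviceMem buf(n * sizeof(float));

    ck_tile::stream_config s{}; // defaults: default stream, timing disabled

    // make_kernel() wraps fill_kernel{} into a callable; launch_kernel() runs it once here
    ck_tile::launch_kernel(s,
                           ck_tile::make_kernel(fill_kernel{},
                                                dim3((n + 255) / 256),
                                                dim3(256),
                                                0,
                                                static_cast<float*>(buf.GetDeviceBuffer()),
                                                n,
                                                1.f));
    return 0;
}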
include/ck_tile/host/reference/reference_batched_dropout.hpp
0 → 100644
View file @
e6bb1dd7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <thread>

namespace ck_tile {

template <typename DataType, typename RandValOutputDataType>
CK_TILE_HOST void reference_batched_dropout(HostTensor<DataType>& in_out_b_m_n,
                                            const HostTensor<RandValOutputDataType>& randval_b_m_n,
                                            const uint8_t& p_undrop_in_uint8_t,
                                            const float scale)
{
    const int N = in_out_b_m_n.mDesc.get_lengths()[2];

    auto f = [&](auto batch, auto m) {
        for(int n = 0; n < N; ++n)
        {
            float tmp = ck_tile::type_convert<float>(in_out_b_m_n(batch, m, n)) * scale;
            in_out_b_m_n(batch, m, n) = randval_b_m_n(batch, m, n) <= p_undrop_in_uint8_t
                                            ? ck_tile::type_convert<DataType>(tmp)
                                            : DataType(0);
        }
    };

    make_ParallelTensorFunctor(f,
                               randval_b_m_n.mDesc.get_lengths()[0],
                               randval_b_m_n.mDesc.get_lengths()[1])(
        std::thread::hardware_concurrency());
}
} // namespace ck_tile
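A short sketch of how this reference could be driven from a host test; the shapes, keep-probability and random values below are illustrative, and the lengths-only HostTensor constructor is assumed as in the sketches above. Elements whose random value exceeds p_undrop_in_uint8_t are zeroed, the rest are rescaled by scale (typically 1 / p_undrop).

#include "ck_tile/host.hpp"
#include <algorithm>
#include <cmath>
#include <cstdint>

int main()
{
    // hypothetical batch=2, M=8, N=16 problem
    ck_tile::HostTensor<float>   x({2, 8, 16});
    ck_tile::HostTensor<uint8_t> randval({2, 8, 16});

    std::fill(x.begin(), x.end(), 1.f);
    std::fill(randval.begin(), randval.end(), uint8_t{100});

    const float   p_undrop          = 0.8f;                                  // keep probability
    const uint8_t p_undrop_in_uint8 = uint8_t(std::floor(p_undrop * 255.f)); // threshold in [0, 255]
    const float   rp_undrop         = 1.f / p_undrop;                        // rescale for kept values

    ck_tile::reference_batched_dropout(x, randval, p_undrop_in_uint8, rp_undrop);
    // with these inputs every element survives (100 <= 204) and becomes 1.25f
    return 0;
}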