gaoqiong / composable_kernel

Commit 3a9ab7a7
Authored Mar 23, 2023 by guangzlu
Parent: 7c4c31cf

    changed random number generation to hiprand

Showing 6 changed files with 127 additions and 17 deletions (+127, -17)
CMakeLists.txt (+2, -0)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp (+2, -1)
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc (+2, -2)
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp (+99, -0)
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp (+13, -8)
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp (+9, -6)
CMakeLists.txt

@@ -7,6 +7,8 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
 enable_testing()
+
+add_definitions(-w)
 set(ROCM_SYMLINK_LIBS OFF)
 find_package(ROCM REQUIRED PATHS /opt/rocm)
example/32_batched_gemm_scale_softmax_gemm/batched_multihead_attention_forward.cpp

@@ -49,7 +49,8 @@ using B1DataType = DataType;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using CDataType        = DataType;
-using ZDataType        = U16;
+// using ZDataType = U16;
+using ZDataType        = F32;
 using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;
example/32_batched_gemm_scale_softmax_gemm/run_batched_multihead_attention_forward.inc

@@ -27,7 +27,7 @@ int run(int argc, char* argv[])
 float p_drop    = 0.1;
 float p_dropout = 1 - p_drop;
-uint16_t p_dropout_in_16bits = uint16_t(std::floor(p_dropout * 65535.0));
+uint16_t p_dropout_in_float  = p_dropout; // uint16_t(std::floor(p_dropout * 65535.0));
 float rp_dropout = 1.0 / p_dropout;
 const unsigned long long seed   = 1;
 const unsigned long long offset = 0;
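This hunk replaces the 16-bit fixed-point dropout threshold with a plain float threshold, matching the float output of hiprand_uniform(). A minimal host-side sketch of the two representations, assuming p_drop = 0.1 as in the hunk (the sketch is ours, not part of the commit):

// Host-side sketch contrasting the threshold representations before and
// after this change; p_drop = 0.1 gives a keep probability of 0.9.
#include <cstdint>
#include <cmath>
#include <cstdio>

int main()
{
    float p_drop    = 0.1f;
    float p_dropout = 1.0f - p_drop; // keep probability

    // Old scheme: quantize the keep probability to a 16-bit fixed-point
    // threshold and compare it against 16-bit random draws.
    uint16_t threshold_16bits = uint16_t(std::floor(p_dropout * 65535.0));

    // New scheme: keep the probability as a float and compare it against
    // hiprand_uniform() draws in (0, 1].
    float threshold_float = p_dropout;

    std::printf("16-bit threshold: %u, float threshold: %f\n",
                (unsigned)threshold_16bits, threshold_float);
    return 0;
}

Note that the committed line still declares the variable as uint16_t, so assigning the float keep probability truncates it on this path.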
@@ -322,7 +322,7 @@ int run(int argc, char* argv[])
 auto ref_dropout         = ReferenceDropoutInstance{};
 auto ref_dropout_invoker = ref_dropout.MakeInvoker();
 auto ref_dropout_argment = ref_dropout.MakeArgument(
-    z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_16bits, rp_dropout);
+    z_g_m_n, a1_g_m_n, a1_g_m_n_drop, p_dropout_in_float, rp_dropout);
 ref_dropout_invoker.Run(ref_dropout_argment);
 // gemm1
include/ck/tensor_operation/gpu/block/blockwise_dropout.hpp

@@ -3,6 +3,9 @@
 #pragma once
+#include "hiprand.h"
+#include "hiprand_kernel.h"
+
 #include "ck/utility/data_type.hpp"
 #include "ck/utility/philox_rand.hpp"

@@ -11,6 +14,23 @@ namespace ck {
 template <typename DataType, typename ThreadSliceDesc_M_K>
 struct BlockwiseDropout
 {
+    //__host__ __device__ BlockwiseDropout(){}
+    //
+    //__host__ __device__ BlockwiseDropout(ushort p_dropout_in_16bits, DataType p_dropout_to_rescale)
+    //{
+    //    p_dropout_16bits  = p_dropout_in_16bits;
+    //    p_dropout_rescale = p_dropout_to_rescale;
+    //}
+    //
+    //__host__ __device__ BlockwiseDropout(float p_dropout_in_float, DataType p_dropout_to_rescale)
+    //{
+    //    p_dropout_float   = p_dropout_in_float;
+    //    p_dropout_rescale = p_dropout_to_rescale;
+    //}
+    //~BlockwiseDropout(){}
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
@@ -50,6 +70,48 @@ struct BlockwiseDropout
         });
     }

+    template <typename CThreadBuffer, bool using_sign_bit = false>
+    __host__ __device__ void ApplyDropout(CThreadBuffer& in_thread_buf, hiprandState_t& state)
+    {
+        auto execute_dropout = [&](bool keep, DataType val) {
+            if constexpr(using_sign_bit)
+                return keep ? val : -val;
+            else
+                return keep ? val * p_dropout_rescale : float(0);
+        };
+
+        constexpr int tmp_size = MRepeat * KRepeat;
+        int hiprand_calls      = tmp_size / 8;
+        float tmp[tmp_size];
+
+        for(int i = 0; i < hiprand_calls; i++)
+        {
+            float tmp_rand = hiprand_uniform(&state);
+            tmp[i]     = tmp_rand;
+            tmp[i + 1] = tmp_rand;
+            tmp[i + 2] = tmp_rand;
+            tmp[i + 3] = tmp_rand;
+            tmp[i + 4] = tmp_rand;
+            tmp[i + 5] = tmp_rand;
+            tmp[i + 6] = tmp_rand;
+            tmp[i + 7] = tmp_rand;
+        }
+        block_sync_lds();
+
+        int tmp_index = 0;
+        static_for<0, MRepeat, 1>{}([&](auto iM) {
+            static_for<0, KRepeat, 1>{}([&](auto iK) {
+                auto offset =
+                    Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
+                in_thread_buf(offset) =
+                    execute_dropout(tmp[tmp_index] <= p_dropout_float, in_thread_buf(offset));
+                tmp_index = tmp_index + 1;
+            });
+        });
+    }
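In the new overload above, each hiprand_uniform() draw is shared across eight tmp slots, but the loop index advances by one per call, so consecutive iterations overwrite overlapping slots and the tail of tmp is never filled. A standalone host-side sketch of a stride-8 fill, under the assumption that one draw per eight consecutive elements was the intent (hypothetical, not part of the commit):

// Host-side sketch (hypothetical): one random draw covers eight
// consecutive elements of tmp with an explicit stride of 8, so groups
// never overlap. std::mt19937 stands in for the device-side hiprand state.
#include <random>
#include <cstdio>

int main()
{
    constexpr int tmp_size = 32;         // stands in for MRepeat * KRepeat
    const int rand_calls   = tmp_size / 8;
    float tmp[tmp_size];

    std::mt19937 gen(1);                 // seed = 1, as in run()
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);

    for(int i = 0; i < rand_calls; i++)
    {
        float r = dist(gen);             // one draw per group of eight
        for(int j = 0; j < 8; j++)
            tmp[i * 8 + j] = r;          // stride-8 fill, no overlap
    }

    for(int k = 0; k < tmp_size; k++)
        std::printf("%d: %f\n", k, tmp[k]);
    return 0;
}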
     template <typename CThreadBuffer, typename ZThreadBuffer, bool using_sign_bit = false>
     __host__ __device__ void
     ApplyDropout(CThreadBuffer& in_thread_buf, ck::philox& ph, ZThreadBuffer& z_thread_buf)

@@ -122,7 +184,44 @@ struct BlockwiseDropout
         });
     }

+    template <typename CThreadBuffer,
+              typename ZThreadBuffer,
+              bool using_sign_bit,
+              typename N0,
+              typename Offset>
+    __host__ __device__ void
+    ApplyDropout(CThreadBuffer& in_thread_buf, hiprandState_t& state, ZThreadBuffer& z_thread_buf)
+    {
+        auto execute_dropout = [&](bool keep, DataType val) {
+            if constexpr(using_sign_bit)
+                return keep ? val : -val;
+            else
+                return keep ? val * p_dropout_rescale : float(0);
+        };
+
+        constexpr int tmp_size = MRepeat * KRepeat / N0{}.value;
+
+        int philox_calls = tmp_size;
+
+        ushort tmp[tmp_size];
+        for(int i = 0; i < philox_calls; i++)
+        {
+            tmp[i] = hiprand_uniform(&state);
+        }
+        block_sync_lds();
+
+        constexpr auto iOffset = Number<tmp_size>{} * Offset{};
+
+        static_for<0, tmp_size, 1>{}([&](auto i) {
+            in_thread_buf(i + iOffset) =
+                execute_dropout(tmp[i.value] <= p_dropout_float, in_thread_buf(i + iOffset));
+            z_thread_buf(i) = tmp[i.value];
+        });
+    }
+
     ushort p_dropout_16bits;
+    float p_dropout_float;
     DataType p_dropout_rescale;
 };
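This z-storing overload keeps the draws in a ushort buffer so they can both drive the keep/drop decision and be written out through z_thread_buf, yet hiprand_uniform() returns a float in (0, 1], which truncates to 0 on assignment to ushort. A standalone sketch of one way to quantize such a draw to 16 bits before storing and comparing (illustrative only, ours, not part of the commit):

// Host-side sketch (hypothetical): quantize a uniform float draw in
// (0, 1] to 16 bits so it can live in a ushort z-buffer and be compared
// against a 16-bit dropout threshold.
#include <cstdint>
#include <cmath>
#include <random>
#include <cstdio>

int main()
{
    std::mt19937 gen(1);
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);

    float p_dropout          = 0.9f; // keep probability, as in run()
    uint16_t threshold_16bit = uint16_t(std::floor(p_dropout * 65535.0f));

    float draw       = dist(gen);                  // stands in for hiprand_uniform
    uint16_t draw_16 = uint16_t(draw * 65535.0f);  // quantized random value
    bool keep        = draw_16 <= threshold_16bit; // 16-bit comparison

    std::printf("draw=%f draw_16=%u keep=%d\n", draw, (unsigned)draw_16, (int)keep);
    return 0;
}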
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle.hpp

@@ -6,8 +6,11 @@
 #include <iostream>
 #include <sstream>

+#include "hiprand.h"
+#include "hiprand_kernel.h"
+
 #include "ck/utility/common_header.hpp"
-#include "ck/utility/philox_rand.hpp"
+// #include "ck/utility/philox_rand.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
@@ -74,7 +77,7 @@ __global__ void
 const index_t batch_count,
 const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
 const C0MatrixMask c0_matrix_mask,
-const ushort p_dropout_in_16bits,
+const float p_dropout_in_float,
 const GemmAccDataType p_dropout_rescale,
 const unsigned long long seed,
 const unsigned long long offset)

@@ -99,7 +102,9 @@ __global__ void
     static_cast<long_index_t>(compute_base_ptr_of_batch.GetLSEBasePtr(g_idx)));
 const index_t global_thread_id = get_thread_global_1d_id();
-ck::philox ph(seed, global_thread_id, offset);
+hiprandState_t state;
+hiprand_init(seed, global_thread_id, offset, &state);
+// ck::philox ph(seed, global_thread_id, offset);
 GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout, IsLseStoring>(
     p_a_grid + a_batch_offset,
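The hunk above is the core of the change: each thread now initializes a hiprandState_t with hiprand_init(seed, subsequence, offset, &state) instead of constructing a ck::philox generator, and the state is threaded through to GridwiseGemm::Run. A minimal self-contained HIP sketch of that per-thread pattern (kernel and buffer names are ours, not the library's):

// Minimal HIP sketch (not from the commit): per-thread hiprand setup and a
// single uniform draw, mirroring the seed/subsequence/offset init above.
#include <hip/hip_runtime.h>
#include <hiprand.h>
#include <hiprand_kernel.h>
#include <cstdio>

__global__ void draw_uniform(float* out, unsigned long long seed, unsigned long long offset)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    hiprandState_t state;
    // Each thread uses its own subsequence so draws are independent.
    hiprand_init(seed, tid, offset, &state);
    out[tid] = hiprand_uniform(&state); // uniform float in (0, 1]
}

int main()
{
    constexpr int n = 64;
    float* d_out;
    hipMalloc(&d_out, n * sizeof(float));
    draw_uniform<<<1, n>>>(d_out, /*seed=*/1ULL, /*offset=*/0ULL);
    float h_out[n];
    hipMemcpy(h_out, d_out, sizeof(h_out), hipMemcpyDeviceToHost);
    std::printf("first draw: %f\n", h_out[0]);
    hipFree(d_out);
    return 0;
}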
@@ -122,9 +127,9 @@ __global__ void
 lse_grid_desc_m,
 block_2_ctile_map,
 c0_matrix_mask,
-p_dropout_in_16bits,
+p_dropout_in_float,
 p_dropout_rescale,
-ph);
+state);
 #else
 ignore = p_a_grid;
 ignore = p_b_grid;
@@ -591,7 +596,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 is_dropout_ = p_dropout > 0.0; //
 p_dropout_  = 1.f - p_dropout;
-p_dropout_in_16bits_ = uint16_t(std::floor(p_dropout_ * 65535.0));
+p_dropout_in_float_  = p_dropout_; // uint16_t(std::floor(p_dropout_ * 65535.0));
 p_dropout_          = 1.f / p_dropout_;
 p_dropout_rescale_  = type_convert<GemmAccDataType>(p_dropout_);
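Walking through the constructor arithmetic above with p_drop = 0.1 from the example: p_dropout_ first holds the keep probability, is copied into the float threshold, and is then inverted in place to become the rescale factor. A standalone sketch (ours; names mirror the members):

// Standalone sketch of the Argument-constructor arithmetic above,
// using p_dropout = 0.1 from the example.
#include <cstdio>

int main()
{
    float p_dropout = 0.1f;                       // drop probability passed in
    bool  is_dropout_         = p_dropout > 0.0f;
    float p_dropout_          = 1.f - p_dropout;  // keep probability: 0.9
    float p_dropout_in_float_ = p_dropout_;       // float threshold: 0.9
    p_dropout_                = 1.f / p_dropout_; // rescale factor: ~1.1111
    std::printf("dropout=%d threshold=%f rescale=%f\n",
                (int)is_dropout_, p_dropout_in_float_, p_dropout_);
    return 0;
}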
@@ -673,7 +678,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
 float p_dropout_;
-ushort p_dropout_in_16bits_;
+ushort p_dropout_in_float_;
 GemmAccDataType p_dropout_rescale_;
 unsigned long long seed_;
 unsigned long long offset_;
@@ -757,7 +762,7 @@ struct DeviceBatchedMultiheadAttentionForward_Xdl_CShuffle
 arg.batch_count_,
 arg.compute_base_ptr_of_batch_,
 arg.c0_matrix_mask_,
-arg.p_dropout_in_16bits_,
+arg.p_dropout_in_float_,
 arg.p_dropout_rescale_,
 arg.seed_,
 arg.offset_);
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle.hpp

@@ -3,8 +3,11 @@
 #pragma once

+#include "hiprand.h"
+#include "hiprand_kernel.h"
+
 #include "ck/utility/common_header.hpp"
-#include "ck/utility/philox_rand.hpp"
+// #include "ck/utility/philox_rand.hpp"
 #include "ck/tensor_description/multi_index_transform_helper.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -443,9 +446,9 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
 const LSEGridDesc_M& lse_grid_desc_m,
 const Block2CTileMap& block_2_ctile_map,
 const C0MatrixMask& c0_matrix_mask,
-const ushort p_dropout_in_16bits,
+const float p_dropout_in_float,
 FloatGemmAcc p_dropout_rescale,
-ck::philox ph)
+hiprandState_t state)
 {
     const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
         p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -792,7 +795,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
     decltype(thread_slice_desc_m_n)>{};

 auto blockwise_dropout = BlockwiseDropout<FloatGemmAcc, decltype(thread_slice_desc_m_n)>{
-    p_dropout_in_16bits, p_dropout_rescale};
+    0, p_dropout_in_float, p_dropout_rescale};

 const index_t num_gemm1_k_block_outer_loop =
     b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
...
@@ -1013,7 +1016,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
false
,
false
,
decltype
(
n0
),
decltype
(
n0
),
decltype
(
i
)>(
decltype
(
i
)>(
acc_thread_buf
,
ph
,
z_tenor_buffer
);
acc_thread_buf
,
state
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
...
@@ -1037,7 +1040,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1037,7 +1040,7 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
// ignore = z_grid_buf;
// ignore = z_grid_buf;
// P_dropped
// P_dropped
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
false
>(
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
acc_thread_buf
),
false
>(
acc_thread_buf
,
ph
);
acc_thread_buf
,
state
);
}
}
}
}
...
...