gaoqiong / composable_kernel

Commit bcbeed99, authored Aug 31, 2023 by danyao12

    dropout deviceop

parent 7d9a6cc2

Showing 2 changed files with 480 additions and 53 deletions:

    include/ck/tensor_operation/gpu/device/impl/device_batched_dropout.hpp   +438  -0
    include/ck/tensor_operation/gpu/grid/gridwise_batched_dropout.hpp         +42   -53
include/ck/tensor_operation/gpu/device/impl/device_batched_dropout.hpp (new file, mode 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <numeric>

#include "ck/utility/common_header.hpp"
#include "ck/utility/philox_rand.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/masking_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_dropout.hpp"
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// GridwiseDropout carries the block-level dropout implementation; the kernel body
// dispatches to its Run, so it must be supplied explicitly (the Invoker below passes
// DeviceBatchedDropout::GridwiseDropout).
template <typename GridwiseDropout,
          typename ZDataType,
          typename AGridDesc_AK0_M_AK1,
          typename ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
          typename Block2CTileMap,
          typename ComputeBasePtrOfStridedBatch>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_dropout(ZDataType* __restrict__ p_z_grid,
                               const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
                               const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
                                   c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                               const Block2CTileMap block_2_ctile_map,
                               const index_t batch_count,
                               const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch,
                               const unsigned long long seed,
                               const unsigned long long offset,
                               const index_t raw_m_padded,
                               const index_t raw_n_padded)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
    // the grid is divided evenly between batches; locate this block's batch index
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t z_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetZBasePtr(g_idx)));

    // counter-based RNG: the (seed, offset) pair defines the philox stream
    ck::philox ph(seed, 0, offset);

    ZDataType* z_matrix_ptr = (p_z_grid == nullptr ? nullptr : p_z_grid + z_batch_offset);

    // each batch owns a raw_m_padded x raw_n_padded tile of the random matrix
    const index_t z_random_matrix_offset = g_idx * raw_m_padded * raw_n_padded;

    GridwiseDropout::Run(z_matrix_ptr,
                         a_grid_desc_ak0_m_ak1,
                         c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                         block_2_ctile_map,
                         ph,
                         z_random_matrix_offset,
                         raw_n_padded);
#else
    ignore = p_z_grid;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3;
    ignore = block_2_ctile_map;
    ignore = batch_count;
    ignore = compute_base_ptr_of_batch;
    ignore = seed;
    ignore = offset;
    ignore = raw_m_padded;
    ignore = raw_n_padded;
#endif // end of arch guard (gfx908 / gfx90a / gfx940 / gfx941 / gfx942)
}

// Computes C = A * B0 * B1
//              ^^^^^^ (Acc0)
//              ^^^^^^^^^^^ (Acc1)
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO, // NumDimGemm1N
          typename GemmDataType,
          typename ZDataType,
          typename GemmAccDataType,
          GemmSpecialization GemmSpec,
          TensorSpecialization ASpec,
          TensorSpecialization BSpec,
          TensorSpecialization B1Spec,
          TensorSpecialization CSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock, // Gemm0NPerBlock
          index_t KPerBlock, // Gemm0KPerBlock
          index_t Gemm1NPerBlock,
          index_t AK1,
          index_t BK1,
          index_t B1K1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave>
struct DeviceBatchedDropout : public BaseOperator
{
    static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0,
                  "Number of dimension must be greater than 0");

    using DeviceOp = DeviceBatchedDropout;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
        Sequence<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO>,
        Sequence<MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock>,
        GemmSpec,
        ASpec,
        BSpec,
        B1Spec,
        CSpec>;

    /*
     * Descriptors for inputs:  Q, K, V, Y, dY, per-row softmax stats
     * Descriptors for outputs: dQ, dK, dV
     */

    // Q in Gemm A position
    static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_gs_ms_ks_lengths,
                                              const std::vector<index_t>& a_gs_ms_ks_strides)
    {
        return Transform::MakeAGridDescriptor_AK0_M_AK1(
            Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides),
            Number<AK1>{});
    }

    // Z in Gemm0 C position
    static auto MakeZGridDescriptor_M_N(const std::vector<index_t>& z_gs_ms_ns_lengths,
                                        const std::vector<index_t>& z_gs_ms_ns_strides)
    {
        return Transform::MakeCGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
    }

    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {}));
    using ZGridDesc_G_M_N     = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {}));
    using KGridDesc_N_K       = decltype(Transform::MakeB0GridDescriptor_N_K({}, {}));
    using ZGridDesc_M_N       = decltype(MakeZGridDescriptor_M_N({}, {}));

    struct ComputeBasePtrOfStridedBatch
    {
        ComputeBasePtrOfStridedBatch() {}

        ComputeBasePtrOfStridedBatch(const ZGridDesc_G_M_N& z_grid_desc_g_m_n)
            : z_grid_desc_g_m_n_(z_grid_desc_g_m_n)
        {
        }

        __host__ __device__ constexpr long_index_t GetZBasePtr(index_t g_idx) const
        {
            return z_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        private:
        ZGridDesc_G_M_N z_grid_desc_g_m_n_;
    };

    using GridwiseDropout = GridwiseBatchedDropout<ZDataType,
                                                   GemmDataType,
                                                   GemmAccDataType,
                                                   AGridDesc_AK0_M_AK1,
                                                   KGridDesc_N_K,
                                                   ZGridDesc_M_N,
                                                   BlockSize,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   Gemm1NPerBlock,
                                                   AK1,
                                                   BK1,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MXdlPerWave,
                                                   NXdlPerWave>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(ZDataType* p_z_grid,
                 const std::vector<index_t>& a_gs_ms_ks_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
                 const std::vector<index_t>& b_gs_ns_ks_strides,
                 const std::vector<index_t>& z_gs_ms_ns_lengths,
                 const std::vector<index_t>& z_gs_ms_ns_strides,
                 std::tuple<unsigned long long, unsigned long long> seeds)
            : p_z_grid_{p_z_grid},
              a_grid_desc_ak0_m_ak1_{
                  DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
              z_grid_desc_m_n_{MakeZGridDescriptor_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
              k_grid_desc_n_k_{
                  Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              z_grid_desc_g_m_n_{
                  Transform::MakeCGridDescriptor_G_M_N(z_gs_ms_ns_lengths, z_gs_ms_ns_strides)},
              block_2_ctile_map_{GridwiseDropout::MakeDefaultBlock2CTileMap(k_grid_desc_n_k_)},
              raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1],
                                            b_gs_ns_ks_lengths[NumDimG + NumDimN - 1],
                                            b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1]},
              batch_count_{z_grid_desc_g_m_n_.GetLength(I0)}
        {
            compute_base_ptr_of_batch_ = ComputeBasePtrOfStridedBatch(z_grid_desc_g_m_n_);

            seed_   = std::get<0>(seeds);
            offset_ = std::get<1>(seeds);

            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_ =
                GridwiseDropout::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
                    z_grid_desc_m_n_);

            // Print();

            m_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[0]);
            n_raw_padded_ = GridwiseDropout::GetPaddedSize(raw_lengths_mz_nz_kz_gemm1nz_[1]);
        }

        void Print() const
        {
            std::cout << "z_grid_desc_g_m_n_: " << z_grid_desc_g_m_n_.GetLength(I0) << ", "
                      << z_grid_desc_g_m_n_.GetLength(I1) << ", "
                      << z_grid_desc_g_m_n_.GetLength(I2) << '\n';
            // z_grid_desc_g_m_n_.Print();
        }

        // pointers
        ZDataType* p_z_grid_;

        // tensor descriptor
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        ZGridDesc_M_N z_grid_desc_m_n_;
        KGridDesc_N_K k_grid_desc_n_k_;

        // batch offsets
        ZGridDesc_G_M_N z_grid_desc_g_m_n_;

        typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3
            c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_;

        // block-to-c-tile map
        typename GridwiseDropout::DefaultBlock2CTileMap block_2_ctile_map_;

        // For robust IsSupportedArgument() check
        std::vector<index_t> raw_lengths_mz_nz_kz_gemm1nz_;

        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;

        unsigned long long seed_;
        unsigned long long offset_;

        index_t m_raw_padded_;
        index_t n_raw_padded_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!DeviceOp::IsSupportedArgument(arg))
            {
                throw std::runtime_error("wrong! unsupported argument");
            }

            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.k_grid_desc_n_k_) * arg.batch_count_;

            float ave_time = 0;

            auto launch_kernel = [&]() {
                const auto kernel = kernel_batched_dropout<
                    GridwiseDropout,
                    ZDataType,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    typename GridwiseDropout::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3,
                    typename GridwiseDropout::DefaultBlock2CTileMap,
                    ComputeBasePtrOfStridedBatch>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_z_grid_,
                                              arg.a_grid_desc_ak0_m_ak1_,
                                              arg.c_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3_,
                                              arg.block_2_ctile_map_,
                                              arg.batch_count_,
                                              arg.compute_base_ptr_of_batch_,
                                              arg.seed_,
                                              arg.offset_,
                                              arg.m_raw_padded_,
                                              arg.n_raw_padded_);
            };

            ave_time = launch_kernel();

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
#if DEBUG_LOG
        arg.Print();
#endif
        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
             ck::get_device_name() == "gfx942"))
        {
            return false;
        }

        return GridwiseDropout::CheckValidity();
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(ZDataType* p_z,
                             const std::vector<index_t>& a_gs_ms_ks_lengths,
                             const std::vector<index_t>& a_gs_ms_ks_strides,
                             const std::vector<index_t>& b_gs_ns_ks_lengths,
                             const std::vector<index_t>& b_gs_ns_ks_strides,
                             const std::vector<index_t>& z_gs_ms_ns_lengths,
                             const std::vector<index_t>& z_gs_ms_ns_strides,
                             std::tuple<unsigned long long, unsigned long long> seeds)
    {
        return Argument{p_z,
                        a_gs_ms_ks_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
                        b_gs_ns_ks_strides,
                        z_gs_ms_ns_lengths,
                        z_gs_ms_ns_strides,
                        seeds};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    // FIXME: constness
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(void* p_z,
                        const std::vector<index_t>& a_gs_ms_ks_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
                        const std::vector<index_t>& b_gs_ns_ks_strides,
                        const std::vector<index_t>& z_gs_ms_ns_lengths,
                        const std::vector<index_t>& z_gs_ms_ns_strides,
                        std::tuple<unsigned long long, unsigned long long> seeds)
    // override
    {
        return std::make_unique<Argument>(static_cast<ZDataType*>(p_z),
                                          a_gs_ms_ks_lengths,
                                          a_gs_ms_ks_strides,
                                          b_gs_ns_ks_lengths,
                                          b_gs_ns_ks_strides,
                                          z_gs_ms_ns_lengths,
                                          z_gs_ms_ns_strides,
                                          seeds);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer()
    // override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBatchedDropout"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << B1K1 << ", "
            << Gemm1NPerBlock << ", "
            << getGemmSpecializationString(GemmSpec) << ", "
            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
            << "B0Spec" << getTensorSpecializationString(BSpec) << ", "
            << "B1Spec" << getTensorSpecializationString(B1Spec) << ", "
            << "CSpec" << getTensorSpecializationString(CSpec)
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
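
For orientation, here is a minimal host-side usage sketch of the new device op. It is not part of the commit: every concrete value below (data types, tile sizes, tensor extents, the helper name run_dropout) is an illustrative assumption, and a real configuration must satisfy GridwiseDropout::CheckValidity() plus the gfx908/gfx90a/gfx94x device check in IsSupportedArgument().

// Hypothetical usage sketch; all template arguments are placeholders.
#include <tuple>
#include <vector>
#include "ck/tensor_operation/gpu/device/impl/device_batched_dropout.hpp"

namespace dev = ck::tensor_operation::device;

using DeviceOp = dev::DeviceBatchedDropout<
    1, 1, 1, 1, 1,          // NumDimG, NumDimM, NumDimN, NumDimK, NumDimO
    ck::half_t,             // GemmDataType
    unsigned short,         // ZDataType: element type of the saved random matrix Z
    float,                  // GemmAccDataType
    dev::GemmSpecialization::MNKOPadding,
    dev::TensorSpecialization::Default, // ASpec
    dev::TensorSpecialization::Default, // BSpec
    dev::TensorSpecialization::Default, // B1Spec
    dev::TensorSpecialization::Default, // CSpec
    256,                    // BlockSize
    128, 128, 32, 128,      // MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock
    8, 8, 2,                // AK1, BK1, B1K1
    32, 32, 1, 4>;          // MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave

// p_z points to device memory for Z; the lengths/strides are flattened
// [G..., M..., K...] views of the batched tensors, as MakeArgument expects.
float run_dropout(unsigned short* p_z,
                  const std::vector<ck::index_t>& a_lengths,
                  const std::vector<ck::index_t>& a_strides,
                  const std::vector<ck::index_t>& b_lengths,
                  const std::vector<ck::index_t>& b_strides,
                  const std::vector<ck::index_t>& z_lengths,
                  const std::vector<ck::index_t>& z_strides,
                  unsigned long long seed,
                  unsigned long long offset)
{
    auto op      = DeviceOp{};
    auto arg     = DeviceOp::MakeArgument(p_z,
                                          a_lengths, a_strides,
                                          b_lengths, b_strides,
                                          z_lengths, z_strides,
                                          std::make_tuple(seed, offset));
    auto invoker = DeviceOp::MakeInvoker();

    if(!op.IsSupportedArgument(arg))
    {
        return -1.f; // unsupported tile configuration or GPU arch
    }
    // time_kernel = false: single launch, no timing loop
    return invoker.Run(arg, StreamConfig{nullptr, false});
}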

include/ck/tensor_operation/gpu/grid/gridwise_batched_dropout.hpp

@@ -26,6 +26,7 @@ template <typename ZDataType,
           index_t MPerBlock,
           index_t NPerBlock,
           index_t KPerBlock,
+          index_t Gemm1NPerBlock,
           index_t AK1Value,
           index_t BK1Value,
           index_t MPerXdl,
@@ -113,21 +114,18 @@ struct GridwiseBatchedDropout
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
             make_tuple(AK0, Number<MPerBlock>{}, AK1),
             make_tuple(Number<MPerBlock + I1>{} * AK1, AK1, I1));
     }

     __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
     {
         // B matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
             make_tuple(BK0, Number<NPerBlock>{}, BK1),
             make_tuple(Number<NPerBlock + I1>{} * BK1, BK1, I1));
     }

     __host__ __device__ static constexpr bool CheckValidity()
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -141,7 +139,7 @@ struct GridwiseBatchedDropout
     __host__ __device__ static constexpr auto
     MakeDefaultBlock2CTileMap(const KGridDesc_N_K& k_grid_desc_n_k)
     {
-        return BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, KPerBlock, KGridDesc_N_K>(
+        return BlockToCTileMap_M00_N0_M01Adapt<NPerBlock, Gemm1NPerBlock, KGridDesc_N_K>(
             k_grid_desc_n_k);
     }
@@ -151,7 +149,6 @@ struct GridwiseBatchedDropout
     using ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 = remove_cvref_t<decltype(
         MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(ZGridDesc_M_N{}))>;

     // S Gemm
     struct Gemm0
     {
@@ -205,15 +202,14 @@ struct GridwiseBatchedDropout
     };

     template <typename Block2CTileMap>
     __device__ static void Run(ZDataType* __restrict__ p_z_grid,
                                const QGridDesc_K0_M_K1& q_grid_desc_k0_m_k1,
                                const ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3&
                                    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                                const Block2CTileMap& block_2_ctile_map,
                                ck::philox& ph,
                                const index_t z_random_matrix_offset,
                                const index_t raw_n_padded)
     {
         // divide block work by [N, K]
         const auto block_work_idx =
@@ -294,7 +290,7 @@ struct GridwiseBatchedDropout
                 1, // DstScalarStrideInVector
                 true>{z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                       make_multi_index(num_gemm0_m_block_outer_loop - 1, // MBlockId
                                        block_work_idx[I0],               // NBlockId
                                        0,                                // MRepeat
                                        0,                                // NRepeat
                                        wave_id[I0],                      // MWaveId
@@ -308,7 +304,7 @@ struct GridwiseBatchedDropout
         // 8d thread_desc in thread scope
         constexpr auto c_thread_lengths =
             s_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();

         // 8d block_desc in block scope
         constexpr auto c_block_lengths =
             s_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2().GetLengths();
@@ -324,15 +320,15 @@ struct GridwiseBatchedDropout
         // works like multi-dimension static_for (static_ford), but provides both the linear
         // index as well as n-d index
         using Acc0TileIterator = SpaceFillingCurve<
             decltype(c_thread_lengths),
             typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
             typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
             false>; // SnakeCurved

         constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
             make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2, M3, M4)),
                        make_unmerge_transform(make_tuple(N0, N1, N2))),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
@@ -346,36 +342,29 @@ struct GridwiseBatchedDropout
             // save z to global
             auto acc0_thread_idx = Acc0TileIterator::GetIndex(I0) + acc0_thread_origin;
             auto m_local  = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
             auto n_local  = block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
             auto m_global = m_local + m_block_data_idx_on_grid;
             auto n_global = n_local + n_block_data_idx_on_grid;

             auto global_tile_id = z_random_matrix_offset +
                                   (m_global / DropoutTile) * DropoutTile * raw_n_padded +
                                   (n_global / DropoutTile) * DropoutTile;
             auto global_elem_id = global_tile_id + (wave_m_n_id[I0] * M4) +
                                   (n_global % DropoutTile) * raw_n_padded;

             blockwise_dropout.template ApplyDropoutAttnBwdSaveZ<decltype(s_slash_p_thread_buf),
                                                                 decltype(z_tensor_buffer),
                                                                 decltype(DropoutTile),
                                                                 true>(
                 s_slash_p_thread_buf, ph, global_elem_id, z_tensor_buffer, raw_n_padded);

             z_thread_copy_vgpr_to_global.Run(z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
                                              z_tensor_buffer,
                                              z_grid_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3,
                                              z_grid_buf);

             // move slice window
             z_thread_copy_vgpr_to_global.MoveDstSliceWindow(
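
The tile/element indexing in the hunk above is what ties the saved Z matrix to the philox stream: each (m_global, n_global) element maps to a deterministic counter offset built from DropoutTile-aligned tiles of the per-batch random matrix, so a later backward pass can regenerate the identical dropout mask from the same (seed, offset) pair. Below is a standalone recomputation of that mapping; all constants are illustrative placeholders (in the kernel, DropoutTile, M4 and the wave id come from the xdlops blockwise-GEMM layout, and raw_n_padded from the padded problem size).

// Standalone sketch of the Z-element offset math used in Run above.
#include <cstdio>

int main()
{
    const int raw_n_padded           = 256; // padded N extent of the random matrix
    const int DropoutTile            = 64;  // hypothetical tile edge
    const int M4                     = 4;   // per-wave sub-tile stride along M
    const int z_random_matrix_offset = 0;   // batch 0 starts at element 0

    const int m_global = 70, n_global = 130; // element coordinates in the C tile
    const int wave_m_id = 1;                 // wave index along M

    // Each DropoutTile x DropoutTile tile owns a contiguous block of counters.
    const int global_tile_id = z_random_matrix_offset +
                               (m_global / DropoutTile) * DropoutTile * raw_n_padded +
                               (n_global / DropoutTile) * DropoutTile;
    // Offset within the tile: wave sub-tile along M plus the column remainder.
    const int global_elem_id = global_tile_id + (wave_m_id * M4) +
                               (n_global % DropoutTile) * raw_n_padded;

    // Same (seed, offset) and same element id always yield the same philox draw,
    // which is what makes the saved mask reproducible across passes.
    std::printf("tile id = %d, element id = %d\n", global_tile_id, global_elem_id);
    return 0;
}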