OpenDAS / apex — commit 88eee5fe
Authored Oct 21, 2021 by Jeff Daily
Parent: 1fd257e2

Commit message: updates to MHA, compilation still broken

Showing 15 changed files with 91 additions and 97 deletions.
Changed files:

    apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp         +0 -0
    apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu         +1 -1
    apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp                   +0 -0
    apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                   +1 -1
    apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp          +0 -0
    apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp                  +0 -0
    apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp  +0 -0
    apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu  +1 -1
    apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp                +0 -0
    apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                +1 -1
    apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp                     +0 -0
    apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                     +1 -1
    apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cpp.cpp            +0 -0
    apex/contrib/csrc/multihead_attn/softmax.h                                       +25 -18
    setup.py                                                                         +61 -74
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp → apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp
File moved (rename only).
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu
@@ -5,7 +5,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
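The commit simply comments the profiler header out, here and in the other *_cuda.cu files below, presumably because it does not translate cleanly under ROCm/HIP. As a hedged aside only: a common portable alternative (the #ifndef guard and the helper names below are my assumptions, not part of this commit) keeps the cudaProfilerStart/cudaProfilerStop markers available on CUDA builds while compiling them away under HIP.

// Hedged sketch — an assumed alternative, not what this commit does.
#ifndef __HIP_PLATFORM_HCC__
#include <cuda_profiler_api.h>
#endif

// Hypothetical helpers: profiler range markers that are no-ops under HIP.
inline void mha_profiler_start() {
#ifndef __HIP_PLATFORM_HCC__
    cudaProfilerStart();   // begins capture when nvprof / Nsight is attached
#endif
}

inline void mha_profiler_stop() {
#ifndef __HIP_PLATFORM_HCC__
    cudaProfilerStop();    // ends the capture region
#endif
}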
apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp → apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp
File moved (rename only).
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu
@@ -6,7 +6,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp → apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp
File moved (rename only).
apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp → apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp
File moved (rename only).
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp
File moved (rename only).
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
@@ -5,7 +5,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp
File moved (rename only).
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu
@@ -5,7 +5,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp
File moved (rename only).
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
@@ -5,7 +5,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
-#include <cuda_profiler_api.h>
+//#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp → apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cpp.cpp
File moved (rename only).
apex/contrib/csrc/multihead_attn/softmax.h
@@ -11,7 +11,14 @@
 #include <cuda_fp16.h>
 #include <cmath>
 
+#ifdef __HIP_PLATFORM_HCC__
+#define WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
+#else
+#define WARP_SHFL_XOR __shfl_xor_sync
+#endif
+
 namespace {
 template <typename Datatype, int ELEMENTS_PER_LDG>
 __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);
@@ -127,7 +134,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
     float val[WARP_BATCH];
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+        val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -152,7 +159,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
     for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
         #pragma unroll
         for (int i = 0; i < WARP_BATCH; ++i) {
-            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+            sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
         }
     }
@@ -351,7 +358,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
     float val[WARP_BATCH];
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+        val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -375,7 +382,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
     for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
         #pragma unroll
         for (int i = 0; i < WARP_BATCH; ++i) {
-            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+            sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
         }
     }
     auto seeds = at::cuda::philox::unpack(philox_args);
@@ -505,7 +512,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
     float val[WARP_BATCH];
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+        val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -529,7 +536,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
     for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
         #pragma unroll
         for (int i = 0; i < WARP_BATCH; ++i) {
-            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+            sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
         }
     }
     curandStatePhilox4_32_10_t state;
@@ -765,7 +772,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
     float val[WARP_BATCH];
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+        val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -790,7 +797,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
     for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
         #pragma unroll
         for (int i = 0; i < WARP_BATCH; ++i) {
-            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+            sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
@@ -1020,7 +1027,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
     float val[WARP_BATCH];
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+        val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1045,7 +1052,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
     for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
         #pragma unroll
         for (int i = 0; i < WARP_BATCH; ++i) {
-            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+            sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
@@ -1243,7 +1250,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
     float val[WARP_BATCH];
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+        val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1268,7 +1275,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
     for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
         #pragma unroll
         for (int i = 0; i < WARP_BATCH; ++i) {
-            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+            sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
@@ -1385,7 +1392,7 @@ bool dispatch_time_masked_softmax(output_t *dst, const input_t *src, const uint8
     return false;
 }
 
-int log2_ceil_native(int value) {
+static int log2_ceil_native(int value) {
     int log2_value = 0;
     while ((1 << log2_value) < value) ++log2_value;
     return log2_value;
@@ -1394,7 +1401,7 @@ int log2_ceil_native(int value) {
 template <typename T>
 __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
 {
-#if CUDA_VERSION >= 9000
+#if CUDA_VERSION >= 9000 && !defined(__HIP_PLATFORM_HCC__)
     return __shfl_xor_sync(mask, value, laneMask, width);
 #else
     return __shfl_xor(value, laneMask, width);
@@ -1835,7 +1842,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
     float val[WARP_BATCH];
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-        val[i] = __shfl_xor_sync(FULL_MASK, max_value[i], offset, WARP_SIZE);
+        val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
     #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1860,7 +1867,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
     for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
         #pragma unroll
         for (int i = 0; i < WARP_BATCH; ++i) {
-            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+            sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
@@ -2305,7 +2312,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, con
     for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
         #pragma unroll
         for (int i = 0; i < WARP_BATCH; ++i) {
-            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+            sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
@@ -2516,7 +2523,7 @@ __global__ void masked_softmax_warp_backward(__half *gradInput, const __half *gr
     for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
         #pragma unroll
         for (int i = 0; i < WARP_BATCH; ++i) {
-            sum[i] += __shfl_xor_sync(FULL_MASK, sum[i], offset, WARP_SIZE);
+            sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
        }
    }
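The softmax.h hunks after the first one are mechanical: every direct __shfl_xor_sync(...) in the warp-level max and sum reductions is routed through the new WARP_SHFL_XOR macro, which on ROCm drops the mask argument and falls back to __shfl_xor, since HIP does not provide the _sync shuffle variants. Below is a minimal, self-contained sketch of the same butterfly-reduction pattern using the macro exactly as defined in the first hunk; the demo kernel and host code are illustrative only and not part of the commit.

#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffff

// Same portability shim as the one added to softmax.h.
#ifdef __HIP_PLATFORM_HCC__
#define WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
#else
#define WARP_SHFL_XOR __shfl_xor_sync
#endif

// Illustrative kernel: each warp reduces its 32 inputs to a maximum with an
// XOR (butterfly) shuffle, mirroring the max/sum loops in softmax.h. After
// the loop, every lane in the warp holds the warp-wide maximum.
__global__ void warp_max_demo(const float *in, float *out) {
    const int WARP_SIZE = 32;
    int tid = blockIdx.x * blockDim.x + threadIdx.x;
    float v = in[tid];
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        float other = WARP_SHFL_XOR(FULL_MASK, v, offset, WARP_SIZE);
        v = v > other ? v : other;
    }
    if (threadIdx.x % WARP_SIZE == 0)
        out[tid / WARP_SIZE] = v;
}

int main() {
    const int N = 64;  // two warps
    float h_in[N], h_out[N / 32];
    for (int i = 0; i < N; ++i) h_in[i] = static_cast<float>(i % 37);
    float *d_in, *d_out;
    cudaMalloc(&d_in, N * sizeof(float));
    cudaMalloc(&d_out, (N / 32) * sizeof(float));
    cudaMemcpy(d_in, h_in, N * sizeof(float), cudaMemcpyHostToDevice);
    warp_max_demo<<<1, N>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, (N / 32) * sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp maxima: %g %g\n", h_out[0], h_out[1]);  // expect 31 and 36
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

The same source builds with nvcc unchanged, while hipcc only ever sees __shfl_xor(value, offset, width), so the unsupported mask parameter never reaches the HIP compiler.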
setup.py
@@ -39,7 +39,6 @@ if IS_ROCM_PYTORCH:
 else:
     rocm_include_dirs = []
-include_dirs = [os.path.join(this_dir, 'csrc')] + rocm_include_dirs
 
 if not torch.cuda.is_available() and not IS_ROCM_PYTORCH:
     # https://github.com/NVIDIA/apex/issues/486
     # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
@@ -157,9 +156,10 @@ if "--distributed_adam" in sys.argv:
     hipcc_args_adam = ['-O3'] + version_dependent_macros
     ext_modules.append(
         CUDAExtension(name='distributed_adam_cuda',
-                      sources=['./apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
-                               './apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
-                      include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/optimizers/'],
+                      sources=['apex/contrib/csrc/optimizers/multi_tensor_distopt_adam.cpp',
+                               'apex/contrib/csrc/optimizers/multi_tensor_distopt_adam_kernel.cu'],
+                      include_dirs=[os.path.join(this_dir, 'csrc'),
+                                    os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
                       extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
                                           'nvcc': nvcc_args_adam if not IS_ROCM_PYTORCH else hipcc_args_adam}))
@@ -280,9 +280,10 @@ if "--xentropy" in sys.argv:
     print("INFO: Building the xentropy extension.")
     ext_modules.append(
         CUDAExtension(name='xentropy_cuda',
-                      sources=['./apex/contrib/csrc/xentropy/interface.cpp',
-                               './apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
-                      include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/xentropy/'],
+                      sources=['apex/contrib/csrc/xentropy/interface.cpp',
+                               'apex/contrib/csrc/xentropy/xentropy_kernel.cu'],
+                      include_dirs=[os.path.join(this_dir, 'csrc'),
+                                    os.path.join(this_dir, 'apex/contrib/csrc/xentropy')],
                       extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                           'nvcc': ['-O3'] + version_dependent_macros}))
@@ -302,9 +303,10 @@ if "--deprecated_fused_adam" in sys.argv:
     hipcc_args_fused_adam = ['-O3'] + version_dependent_macros
     ext_modules.append(
         CUDAExtension(name='fused_adam_cuda',
-                      sources=['./apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
-                               './apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
-                      include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/optimizers/'],
+                      sources=['apex/contrib/csrc/optimizers/fused_adam_cuda.cpp',
+                               'apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu'],
+                      include_dirs=[os.path.join(this_dir, 'csrc'),
+                                    os.path.join(this_dir, 'apex/contrib/csrc/optimizers')],
                       extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                           'nvcc': nvcc_args_fused_adam if not IS_ROCM_PYTORCH else hipcc_args_fused_adam}))
@@ -363,7 +365,7 @@ if "--fast_layer_norm" in sys.argv:
                                               '-gencode', 'arch=compute_70,code=sm_70',
                                               '-U__CUDA_NO_HALF_OPERATORS__',
                                               '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                              '-I./apex/contrib/csrc/layer_norm/',
+                                              '-Iapex/contrib/csrc/layer_norm',
                                               '--expt-relaxed-constexpr',
                                               '--expt-extended-lambda',
                                               '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
@@ -387,99 +389,84 @@ if "--fast_multihead_attn" in sys.argv:
             cc_flag.append('arch=compute_80,code=sm_80')
 
     subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
 
-    nvcc_args_mha = ['-O3',
-                     '-gencode', 'arch=compute_70,code=sm_70',
-                     '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                     '-U__CUDA_NO_HALF_OPERATORS__',
-                     '-U__CUDA_NO_HALF_CONVERSIONS__',
-                     '--expt-relaxed-constexpr',
-                     '--expt-extended-lambda',
-                     '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
-    hipcc_args_mha = ['-O3',
-                      '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                      '-U__CUDA_NO_HALF_OPERATORS__',
-                      '-U__CUDA_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
+    nvcc_args_mha = ['-O3',
+                     '-gencode', 'arch=compute_70,code=sm_70',
+                     '-Iapex/contrib/csrc/multihead_attn/cutlass',
+                     '-U__CUDA_NO_HALF_OPERATORS__',
+                     '-U__CUDA_NO_HALF_CONVERSIONS__',
+                     '--expt-relaxed-constexpr',
+                     '--expt-extended-lambda',
+                     '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag
+    hipcc_args_mha = ['-O3',
+                      '-Iapex/contrib/csrc/multihead_attn/cutlass',
+                      '-I/opt/rocm/include/hiprand',
+                      '-I/opt/rocm/include/rocrand',
+                      '-U__HIP_NO_HALF_OPERATORS__',
+                      '-U__HIP_NO_HALF_CONVERSIONS__'] + version_dependent_macros + generator_flag
 
     ext_modules.append(
         CUDAExtension(name='fast_additive_mask_softmax_dropout',
-                      sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout.cpp',
+                      sources=['apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cpp.cpp',
                                'apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu'],
                       include_dirs=[os.path.join(this_dir, 'csrc'),
                                     os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
                       extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                          'nvcc': ['-O3',
-                                                   '-gencode', 'arch=compute_70,code=sm_70',
-                                                   '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                   '-U__CUDA_NO_HALF_OPERATORS__',
-                                                   '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                   '--expt-relaxed-constexpr',
-                                                   '--expt-extended-lambda',
-                                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                          'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_mask_softmax_dropout',
-                      sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout.cpp',
+                      sources=['apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cpp.cpp',
                                'apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu'],
                       include_dirs=[os.path.join(this_dir, 'csrc'),
                                     os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
                       extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                          'nvcc': ['-O3',
-                                                   '-gencode', 'arch=compute_70,code=sm_70',
-                                                   '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                   '-U__CUDA_NO_HALF_OPERATORS__',
-                                                   '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                   '--expt-relaxed-constexpr',
-                                                   '--expt-extended-lambda',
-                                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                          'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_bias_additive_mask',
-                      sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp',
+                      sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp',
                                'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu'],
                       include_dirs=[os.path.join(this_dir, 'csrc'),
                                     os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
                       extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                          'nvcc': ['-O3',
-                                                   '-gencode', 'arch=compute_70,code=sm_70',
-                                                   '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                   '-U__CUDA_NO_HALF_OPERATORS__',
-                                                   '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                   '--expt-relaxed-constexpr',
-                                                   '--expt-extended-lambda',
-                                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                          'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_bias',
-                      sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias.cpp',
+                      sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp',
                                'apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu'],
                       include_dirs=[os.path.join(this_dir, 'csrc'),
                                     os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
                       extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                          'nvcc': ['-O3',
-                                                   '-gencode', 'arch=compute_70,code=sm_70',
-                                                   '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                   '-U__CUDA_NO_HALF_OPERATORS__',
-                                                   '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                   '--expt-relaxed-constexpr',
-                                                   '--expt-extended-lambda',
-                                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                          'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
    ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn',
-                      sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn.cpp',
+                      sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp',
                                'apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu'],
                       include_dirs=[os.path.join(this_dir, 'csrc'),
                                     os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
                       extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                          'nvcc': ['-O3',
-                                                   '-gencode', 'arch=compute_70,code=sm_70',
-                                                   '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                   '-U__CUDA_NO_HALF_OPERATORS__',
-                                                   '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                   '--expt-relaxed-constexpr',
-                                                   '--expt-extended-lambda',
-                                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                          'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_self_multihead_attn_norm_add',
-                      sources=['./apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add.cpp',
-                               './apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
-                      include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/multihead_attn/'],
+                      sources=['apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cpp.cpp',
+                               'apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu'],
+                      include_dirs=[os.path.join(this_dir, 'csrc'),
+                                    os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
                       extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                           'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_encdec_multihead_attn',
-                      sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn.cpp',
+                      sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp',
                                'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu'],
                       include_dirs=[os.path.join(this_dir, 'csrc'),
                                     os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
                       extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
-                                          'nvcc': ['-O3',
-                                                   '-gencode', 'arch=compute_70,code=sm_70',
-                                                   '-I./apex/contrib/csrc/multihead_attn/cutlass/',
-                                                   '-U__CUDA_NO_HALF_OPERATORS__',
-                                                   '-U__CUDA_NO_HALF_CONVERSIONS__',
-                                                   '--expt-relaxed-constexpr',
-                                                   '--expt-extended-lambda',
-                                                   '--use_fast_math'] + version_dependent_macros + generator_flag + cc_flag}))
+                                          'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
     ext_modules.append(
         CUDAExtension(name='fast_encdec_multihead_attn_norm_add',
-                      sources=['./apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add.cpp',
-                               './apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
-                      include_dirs=include_dirs + [this_dir + '/apex/contrib/csrc/multihead_attn/'],
+                      sources=['apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cpp.cpp',
+                               'apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu'],
+                      include_dirs=[os.path.join(this_dir, 'csrc'),
+                                    os.path.join(this_dir, 'apex/contrib/csrc/multihead_attn')],
                       extra_compile_args={'cxx': ['-O3',] + version_dependent_macros + generator_flag,
                                           'nvcc': nvcc_args_mha if not IS_ROCM_PYTORCH else hipcc_args_mha}))
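One non-obvious detail in these flag lists: -U__CUDA_NO_HALF_OPERATORS__ and -U__CUDA_NO_HALF_CONVERSIONS__ (and the new -U__HIP_NO_HALF_* counterparts in hipcc_args_mha) undefine macros that the PyTorch extension build normally sets, which re-enables the operator and conversion overloads for __half inside these kernels. A hedged illustration of what that buys follows; the kernel is hypothetical and not code from this commit, and it compiles as a standalone .cu translation unit.

#include <cuda_fp16.h>

// With __CUDA_NO_HALF_OPERATORS__ left defined (PyTorch's default for
// extensions), the natural expression below does not compile and the
// intrinsic form in the comment must be used instead.
__global__ void half_axpy(const __half *x, __half *y, __half alpha, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = alpha * x[i] + y[i];                   // needs -U__CUDA_NO_HALF_OPERATORS__
        // y[i] = __hadd(__hmul(alpha, x[i]), y[i]);  // always-available intrinsic form
    }
}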