OpenDAS / apex
Commit 40e15362, authored Aug 02, 2022 by hanbao, committed by hubertlu-tw on Aug 23, 2022

add customized fused op index multiplication (#1438)

Co-authored-by: Han Bao <hbao@nvidia.com>

parent 96850dfa
Showing 6 changed files with 886 additions and 0 deletions (+886 / -0):

  apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp        +139 / -0
  apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu  +479 / -0
  apex/contrib/index_mul_2d/__init__.py                         +1 / -0
  apex/contrib/index_mul_2d/index_mul_2d.py                   +144 / -0
  apex/contrib/test/index_mul_2d/test_index_mul_2d.py         +106 / -0
  setup.py                                                     +17 / -0
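Judging from the reference expression in1[idx1] * in2 used in the new test file, the fused op computes out[i] = in1[idx1[i]] * in2[i] for each row i. A minimal pure-PyTorch sketch of the same semantics, useful as a mental model while reading the kernels below (the helper name is illustrative, not part of the commit):

import torch

def index_mul_2d_reference(in1, in2, idx1):
    # Gather rows of in1 by idx1, then multiply elementwise with in2.
    # in1: (n1, fea_dim), in2: (n2, fea_dim), idx1: (n2,) int64 indices into in1.
    return in1[idx1] * in2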
apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp (new file, mode 100644)
#include <torch/torch.h>
#include <vector>
#include <cstdint>

void index_mul_2d_float_foward_cuda(at::Tensor &out, const at::Tensor &in1,
                                    const at::Tensor &in2, const at::Tensor &idx1);

void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out,
                                      const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1);

void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2,
                                               const at::Tensor &grad_out, const at::Tensor &grad_grad_in1,
                                               const at::Tensor &grad_grad_in2, const at::Tensor &in1,
                                               const at::Tensor &in2, const at::Tensor &idx1);

void index_mul_2d_half_foward_cuda(at::Tensor &out, const at::Tensor &in1,
                                   const at::Tensor &in2, const at::Tensor &idx1);

void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out,
                                     const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1);

void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2,
                                              const at::Tensor &grad_out, const at::Tensor &grad_grad_in1,
                                              const at::Tensor &grad_grad_in2, const at::Tensor &in1,
                                              const at::Tensor &in2, const at::Tensor &idx1);

#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

void index_mul_2d_float_forward(at::Tensor &out, const at::Tensor &in1,
                                const at::Tensor &in2, const at::Tensor &idx1) {
    return index_mul_2d_float_foward_cuda(out, in1, in2, idx1);
}

void index_mul_2d_float_backward(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out,
                                 const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1) {
    return index_mul_2d_float_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);
}

void index_mul_2d_float_backwrad_backward(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2,
                                          const at::Tensor &grad_out, const at::Tensor &grad_grad_in1,
                                          const at::Tensor &grad_grad_in2, const at::Tensor &in1,
                                          const at::Tensor &in2, const at::Tensor &idx1) {
    return index_mul_2d_float_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out,
                                                     grad_grad_in1, grad_grad_in2, in1, in2, idx1);
}

void index_mul_2d_half_forward(at::Tensor &out, const at::Tensor &in1,
                               const at::Tensor &in2, const at::Tensor &idx1) {
    return index_mul_2d_half_foward_cuda(out, in1, in2, idx1);
}

void index_mul_2d_half_backward(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out,
                                const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1) {
    return index_mul_2d_half_backward_cuda(grad_in1, grad_in2, grad_out, in1, in2, idx1);
}

void index_mul_2d_half_backwrad_backward(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2,
                                         const at::Tensor &grad_out, const at::Tensor &grad_grad_in1,
                                         const at::Tensor &grad_grad_in2, const at::Tensor &in1,
                                         const at::Tensor &in2, const at::Tensor &idx1) {
    return index_mul_2d_half_backward_backward_cuda(grad_grad_out, grad_in1, grad_in2, grad_out,
                                                    grad_grad_in1, grad_grad_in2, in1, in2, idx1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("float_forward", &index_mul_2d_float_forward, "index mul float calculation forward (CUDA)");
    m.def("float_backward", &index_mul_2d_float_backward, "index mul float calculation backward (CUDA)");
    m.def("float_backward_backward", &index_mul_2d_float_backwrad_backward, "index mul float calculation backward backward (CUDA)");
    m.def("half_forward", &index_mul_2d_half_forward, "index mul half calculation forward (CUDA)");
    m.def("half_backward", &index_mul_2d_half_backward, "index mul half calculation backward (CUDA)");
    m.def("half_backward_backward", &index_mul_2d_half_backwrad_backward, "index mul half calculation backward backward (CUDA)");
}
apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu (new file, mode 100644)
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Atomic.cuh>

__global__ void index_mul_2d_float_dim64(float *out, const float *in1, const float *in2,
                                         const int64_t *idx1, const int64_t size) {
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    constexpr int fea_dim = 64;

    if (start_idx < size) {
        int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
        int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
        float4 res, src1, src2;
        src1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
        src2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];
        res.x = src1.x * src2.x;
        res.y = src1.y * src2.y;
        res.z = src1.z * src2.z;
        res.w = src1.w * src2.w;
        reinterpret_cast<float4 *>(out)[vec_idx2] = res;
    }
}

__global__ void index_mul_2d_float(float *out, const float *in1, const float *in2,
                                   const int64_t *idx1, const int64_t size, const int64_t fea_dim) {
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = (idx1[start_idx] * fea_dim);
        int64_t vec_idx2 = (start_idx * fea_dim);
        for (int i = tidx; i < fea_dim; i += stride) {
            out[vec_idx2 + i] = in1[vec_idx1 + i] * in2[vec_idx2 + i];
        }
    }
}

__global__ void index_mul_2d_half(at::Half *out, const at::Half *in1, const at::Half *in2,
                                  const int64_t *idx1, const int64_t size, const int64_t fea_dim) {
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = (idx1[start_idx] * fea_dim);
        int64_t vec_idx2 = (start_idx * fea_dim);
        for (int i = tidx; i < fea_dim; i += stride) {
            out[vec_idx2 + i] = at::Half(static_cast<float>(in1[vec_idx1 + i]) *
                                         static_cast<float>(in2[vec_idx2 + i]));
        }
    }
}

__global__ void index_mul_2d_grad_float_dim64(float *grad_in1, float *grad_in2, const float *grad_out,
                                              const float *in1, const float *in2, const int64_t *idx1,
                                              const int64_t size) {
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    constexpr int fea_dim = 64;

    if (start_idx < size) {
        int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
        int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
        float4 src_in1, src_in2, src_grad_out, dst_grad_in2;
        src_grad_out = reinterpret_cast<const float4 *>(grad_out)[vec_idx2];
        src_in1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
        src_in2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];

        int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4;
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_out.x * src_in2.x);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_out.y * src_in2.y);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_out.z * src_in2.z);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_out.w * src_in2.w);

        dst_grad_in2.x = src_grad_out.x * src_in1.x;
        dst_grad_in2.y = src_grad_out.y * src_in1.y;
        dst_grad_in2.z = src_grad_out.z * src_in1.z;
        dst_grad_in2.w = src_grad_out.w * src_in1.w;
        reinterpret_cast<float4 *>(grad_in2)[vec_idx2] = dst_grad_in2;
    }
}

__global__ void index_mul_2d_grad_float(float *grad_in1, float *grad_in2, const float *grad_out,
                                        const float *in1, const float *in2, const int64_t *idx1,
                                        const int64_t size, const int64_t fea_dim) {
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = idx1[start_idx] * fea_dim;
        int64_t vec_idx2 = start_idx * fea_dim;
        for (int i = tidx; i < fea_dim; i += stride) {
            float src_in1 = in1[vec_idx1 + i];
            float src_in2 = in2[vec_idx2 + i];
            float src_grad_out = grad_out[vec_idx2 + i];
            grad_in2[vec_idx2 + i] = src_grad_out * src_in1;
            gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_out * src_in2);
        }
    }
}

__global__ void index_mul_2d_grad_half(at::Half *grad_in1, at::Half *grad_in2, const at::Half *grad_out,
                                       const at::Half *in1, const at::Half *in2, const int64_t *idx1,
                                       const int64_t size, const int64_t fea_dim) {
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = idx1[start_idx] * fea_dim;
        int64_t vec_idx2 = start_idx * fea_dim;
        for (int i = tidx; i < fea_dim; i += stride) {
            float src_in1 = static_cast<float>(in1[vec_idx1 + i]);
            float src_in2 = static_cast<float>(in2[vec_idx2 + i]);
            float src_grad_out = static_cast<float>(grad_out[vec_idx2 + i]);
            grad_in2[vec_idx2 + i] = at::Half(src_grad_out * src_in1);
            gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_out * src_in2));
        }
    }
}

__global__ void index_mul_2d_grad_grad_float_dim64(float *grad_grad_out, float *grad_in1, float *grad_in2,
                                                   const float *grad_out, const float *grad_grad_in1,
                                                   const float *grad_grad_in2, const float *in1, const float *in2,
                                                   const int64_t *idx1, const int64_t size) {
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    constexpr int fea_dim = 64;

    if (start_idx < size) {
        int64_t vec_idx1 = (idx1[start_idx] * fea_dim) / 4 + tidx;
        int64_t vec_idx2 = (start_idx * fea_dim) / 4 + tidx;
        float4 src_grad_grad_in1, src_in1, src_grad_grad_in2, src_in2, src_grad_out;
        float4 dst_grad_grad_out, dst_grad_in2;
        src_grad_grad_in1 = reinterpret_cast<const float4 *>(grad_grad_in1)[vec_idx1];
        src_in1 = reinterpret_cast<const float4 *>(in1)[vec_idx1];
        src_grad_grad_in2 = reinterpret_cast<const float4 *>(grad_grad_in2)[vec_idx2];
        src_in2 = reinterpret_cast<const float4 *>(in2)[vec_idx2];

        dst_grad_grad_out.x = src_grad_grad_in1.x * src_in2.x + src_grad_grad_in2.x * src_in1.x;
        dst_grad_grad_out.y = src_grad_grad_in1.y * src_in2.y + src_grad_grad_in2.y * src_in1.y;
        dst_grad_grad_out.z = src_grad_grad_in1.z * src_in2.z + src_grad_grad_in2.z * src_in1.z;
        dst_grad_grad_out.w = src_grad_grad_in1.w * src_in2.w + src_grad_grad_in2.w * src_in1.w;
        reinterpret_cast<float4 *>(grad_grad_out)[vec_idx2] = dst_grad_grad_out;

        src_grad_out = reinterpret_cast<const float4 *>(grad_out)[vec_idx2];
        int64_t grad_in1_base_idx = idx1[start_idx] * fea_dim + tidx * 4;
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 0, src_grad_grad_in2.x * src_grad_out.x);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 1, src_grad_grad_in2.y * src_grad_out.y);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 2, src_grad_grad_in2.z * src_grad_out.z);
        gpuAtomicAdd(grad_in1 + grad_in1_base_idx + 3, src_grad_grad_in2.w * src_grad_out.w);

        dst_grad_in2.x = src_grad_grad_in1.x * src_grad_out.x;
        dst_grad_in2.y = src_grad_grad_in1.y * src_grad_out.y;
        dst_grad_in2.z = src_grad_grad_in1.z * src_grad_out.z;
        dst_grad_in2.w = src_grad_grad_in1.w * src_grad_out.w;
        reinterpret_cast<float4 *>(grad_in2)[vec_idx2] = dst_grad_in2;
    }
}

__global__ void index_mul_2d_grad_grad_float(float *grad_grad_out, float *grad_in1, float *grad_in2,
                                             const float *grad_out, const float *grad_grad_in1,
                                             const float *grad_grad_in2, const float *in1, const float *in2,
                                             const int64_t *idx1, const int64_t size, const int64_t fea_dim) {
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = idx1[start_idx] * fea_dim;
        int64_t vec_idx2 = start_idx * fea_dim;
        for (int i = tidx; i < fea_dim; i += stride) {
            float src_grad_grad_in1 = grad_grad_in1[vec_idx1 + i];
            float src_grad_grad_in2 = grad_grad_in2[vec_idx2 + i];
            float src_in1 = in1[vec_idx1 + i];
            float src_in2 = in2[vec_idx2 + i];
            float src_grad_out = grad_out[vec_idx2 + i];
            grad_grad_out[vec_idx2 + i] = src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1;
            grad_in2[vec_idx2 + i] = src_grad_grad_in1 * src_grad_out;
            gpuAtomicAdd(grad_in1 + vec_idx1 + i, src_grad_grad_in2 * src_grad_out);
        }
    }
}

__global__ void index_mul_2d_grad_grad_half(at::Half *grad_grad_out, at::Half *grad_in1, at::Half *grad_in2,
                                            const at::Half *grad_out, const at::Half *grad_grad_in1,
                                            const at::Half *grad_grad_in2, const at::Half *in1, const at::Half *in2,
                                            const int64_t *idx1, const int64_t size, const int64_t fea_dim) {
    const int tidx = threadIdx.x;
    const int tidy = threadIdx.y;
    const int bidx = blockIdx.x;
    const int start_idx = bidx * blockDim.y + tidy;
    const int stride = blockDim.x;

    if (start_idx < size) {
        int64_t vec_idx1 = idx1[start_idx] * fea_dim;
        int64_t vec_idx2 = start_idx * fea_dim;
        for (int i = tidx; i < fea_dim; i += stride) {
            float src_grad_grad_in1 = static_cast<float>(grad_grad_in1[vec_idx1 + i]);
            float src_grad_grad_in2 = static_cast<float>(grad_grad_in2[vec_idx2 + i]);
            float src_in1 = static_cast<float>(in1[vec_idx1 + i]);
            float src_in2 = static_cast<float>(in2[vec_idx2 + i]);
            float src_grad_out = static_cast<float>(grad_out[vec_idx2 + i]);
            grad_grad_out[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_in2 + src_grad_grad_in2 * src_in1);
            grad_in2[vec_idx2 + i] = at::Half(src_grad_grad_in1 * src_grad_out);
            gpuAtomicAdd(grad_in1 + vec_idx1 + i, at::Half(src_grad_grad_in2 * src_grad_out));
        }
    }
}

void index_mul_2d_float_foward_cuda(at::Tensor &out, const at::Tensor &in1,
                                    const at::Tensor &in2, const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0) {
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (fea_dim == 64) {
        const int BLOCK_THREADS_DIMX = 16;
        const int BLOCK_THREADS_DIMY = 16;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        index_mul_2d_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
            out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
            idx1.data_ptr<int64_t>(), size);
    } else {
        const int BLOCK_THREADS_DIMX = 32;
        const int BLOCK_THREADS_DIMY = 8;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        index_mul_2d_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
            out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
            idx1.data_ptr<int64_t>(), size, fea_dim);
    }
    AT_CUDA_CHECK(cudaGetLastError());
}

void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out,
                                      const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0) {
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (fea_dim == 64) {
        const int BLOCK_THREADS_DIMX = 16;
        const int BLOCK_THREADS_DIMY = 16;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        index_mul_2d_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
            grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
            in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
        AT_CUDA_CHECK(cudaGetLastError());
    } else {
        const int BLOCK_THREADS_DIMX = 32;
        const int BLOCK_THREADS_DIMY = 8;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        index_mul_2d_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
            grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
            in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
    }
}

void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2,
                                               const at::Tensor &grad_out, const at::Tensor &grad_grad_in1,
                                               const at::Tensor &grad_grad_in2, const at::Tensor &in1,
                                               const at::Tensor &in2, const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0) {
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if (fea_dim == 64) {
        const int BLOCK_THREADS_DIMX = 16;
        const int BLOCK_THREADS_DIMY = 16;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        index_mul_2d_grad_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
            grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
            grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
            in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
    } else {
        const int BLOCK_THREADS_DIMX = 32;
        const int BLOCK_THREADS_DIMY = 8;
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        index_mul_2d_grad_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
            grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
            grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
            in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
    }
    AT_CUDA_CHECK(cudaGetLastError());
}

void index_mul_2d_half_foward_cuda(at::Tensor &out, const at::Tensor &in1,
                                   const at::Tensor &in2, const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0) {
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const int BLOCK_THREADS_DIMX = 32;
    const int BLOCK_THREADS_DIMY = 8;
    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
    index_mul_2d_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
        out.data_ptr<at::Half>(), in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(),
        idx1.data_ptr<int64_t>(), size, fea_dim);
    AT_CUDA_CHECK(cudaGetLastError());
}

void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1, at::Tensor &grad_in2, const at::Tensor &grad_out,
                                     const at::Tensor &in1, const at::Tensor &in2, const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0) {
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const int BLOCK_THREADS_DIMX = 32;
    const int BLOCK_THREADS_DIMY = 8;
    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
    index_mul_2d_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
        grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(), grad_out.data_ptr<at::Half>(),
        in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
}

void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out, at::Tensor &grad_in1, at::Tensor &grad_in2,
                                              const at::Tensor &grad_out, const at::Tensor &grad_grad_in1,
                                              const at::Tensor &grad_grad_in2, const at::Tensor &in1,
                                              const at::Tensor &in2, const at::Tensor &idx1) {
    const int64_t size = in2.size(0);
    const int64_t fea_dim = in2.size(1);
    if (size < 0) {
        return;
    }

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const int BLOCK_THREADS_DIMX = 32;
    const int BLOCK_THREADS_DIMY = 8;
    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
    index_mul_2d_grad_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
        grad_grad_out.data_ptr<at::Half>(), grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(),
        grad_out.data_ptr<at::Half>(), grad_grad_in1.data_ptr<at::Half>(), grad_grad_in2.data_ptr<at::Half>(),
        in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
    AT_CUDA_CHECK(cudaGetLastError());
}
\ No newline at end of file
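Reading the host-side wrappers above: the float forward/backward/double-backward paths dispatch to a float4-vectorized kernel only when fea_dim == 64 (16 threads in x each handle one float4, i.e. 4 floats), and otherwise fall back to a generic kernel where 32 threads in x stride over the row; in both cases a block covers BLOCK_THREADS_DIMY rows and the grid is ceil(size / BLOCK_THREADS_DIMY) blocks. A small Python sketch of that launch-configuration logic (names are illustrative, not part of the extension):

import math

def launch_config(size, fea_dim):
    # Mirrors the block/grid choice made in the float forward/backward wrappers.
    if fea_dim == 64:
        block = (16, 16, 1)   # x: 16 float4 lanes covering 64 floats, y: 16 rows per block
    else:
        block = (32, 8, 1)    # x: strided loop over fea_dim, y: 8 rows per block
    grid = (math.ceil(size / block[1]), 1, 1)
    return grid, block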
apex/contrib/index_mul_2d/__init__.py (new file, mode 100644)
from .index_mul_2d import index_mul_2d
apex/contrib/index_mul_2d/index_mul_2d.py (new file, mode 100644)
import torch

import fused_index_mul_2d


class IndexMul2d_(torch.autograd.Function):
    '''
    Currently only support index in dimension 0 with a 2-dimension tensor.
    The shape of indexed in1 must be same with in2. Now this kernel does not support broadcast.
    The datatype must be float32 or float16.
    '''
    @staticmethod
    def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor) -> torch.Tensor:
        assert in2.size(0) == idx1.size(0)
        if ((in1.dtype != torch.float32 and in1.dtype != torch.half) or in2.dtype != in1.dtype):
            raise RuntimeError("input1'dtype and input2's dtype must be fp32 or fp16. And input type must be same")
        if (in1.dim() != 2 or in2.dim() != 2):
            raise RuntimeError("in1 and in2 must be 2-dimension tensor.")
        if (idx1.dim() != 1):
            raise RuntimeError("idx1 must be 1-dimension tensor.")

        if not in1.is_contiguous():
            in1 = in1.contiguous()
        if not in2.is_contiguous():
            in2 = in2.contiguous()
        if not idx1.is_contiguous():
            idx1 = idx1.contiguous()

        assert in1.is_contiguous()
        assert in2.is_contiguous()
        assert idx1.is_contiguous()

        out = torch.empty_like(in2)

        if (in1.dtype == torch.float32):
            fused_index_mul_2d.float_forward(out, in1, in2, idx1)
        elif (in1.dtype == torch.half):
            fused_index_mul_2d.half_forward(out, in1, in2, idx1)

        ctx.for_backwards = (in1, in2, idx1)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        in1, in2, idx1 = ctx.for_backwards

        grad_in1, grad_in2 = index_mul_2d_backward(in1, in2, idx1, grad_out)

        return grad_in1, grad_in2, None


class IndexMul2dBackward_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, in1: torch.Tensor, in2: torch.Tensor, idx1: torch.Tensor,
                grad_out: torch.Tensor) -> torch.Tensor:
        if not in1.is_contiguous():
            in1 = in1.contiguous()
        if not in2.is_contiguous():
            in2 = in2.contiguous()
        if not idx1.is_contiguous():
            idx1 = idx1.contiguous()
        if not grad_out.is_contiguous():
            grad_out = grad_out.contiguous()

        assert in1.is_contiguous()
        assert in2.is_contiguous()
        assert idx1.is_contiguous()
        assert grad_out.is_contiguous()

        grad_in1 = torch.zeros_like(in1)
        grad_in2 = torch.empty_like(in2)

        if (in1.dtype == torch.float32):
            fused_index_mul_2d.float_backward(grad_in1, grad_in2, grad_out, in1, in2, idx1)
        elif (in1.dtype == torch.half):
            fused_index_mul_2d.half_backward(grad_in1, grad_in2, grad_out, in1, in2, idx1)

        ctx.for_backwards = (in1, in2, idx1, grad_out)
        return grad_in1, grad_in2

    @staticmethod
    def backward(ctx, grad_grad_in1, grad_grad_in2):
        if not grad_grad_in1.is_contiguous():
            grad_grad_in1 = grad_grad_in1.contiguous()
        if not grad_grad_in2.is_contiguous():
            grad_grad_in2 = grad_grad_in2.contiguous()

        assert grad_grad_in1.is_contiguous()
        assert grad_grad_in2.is_contiguous()

        in1, in2, idx1, grad_out = ctx.for_backwards

        grad_in1 = torch.zeros_like(in1)
        grad_in2 = torch.empty_like(in2)
        grad_grad_out = torch.empty_like(grad_out)

        if (in1.dtype == torch.float32):
            fused_index_mul_2d.float_backward_backward(grad_grad_out, grad_in1, grad_in2, grad_out,
                                                       grad_grad_in1, grad_grad_in2, in1, in2, idx1)
        elif (in1.dtype == torch.half):
            fused_index_mul_2d.half_backward_backward(grad_grad_out, grad_in1, grad_in2, grad_out,
                                                      grad_grad_in1, grad_grad_in2, in1, in2, idx1)

        return grad_in1, grad_in2, None, grad_grad_out


index_mul_2d = IndexMul2d_.apply
index_mul_2d_backward = IndexMul2dBackward_.apply
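Given the constraints stated in the IndexMul2d_ docstring (2-D float32/float16 inputs, indexing along dimension 0, no broadcasting), a minimal usage sketch of the exported index_mul_2d function could look like this (tensor names and sizes are illustrative only):

import torch
from apex.contrib.index_mul_2d import index_mul_2d

in1 = torch.randn(1000, 64, device="cuda", requires_grad=True)   # table of rows to gather from
in2 = torch.randn(50000, 64, device="cuda", requires_grad=True)  # per-output operand
idx1 = torch.randint(0, 1000, (50000,), device="cuda")           # int64 row indices into in1

out = index_mul_2d(in1, in2, idx1)  # fused equivalent of in1[idx1] * in2
out.sum().backward()                # double backward is routed through IndexMul2dBackward_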
apex/contrib/test/index_mul_2d/test_index_mul_2d.py (new file, mode 100644)
import random
import unittest

import torch
import torch.nn.functional as F

HAS_INDEX_MUL_2D_RELU = None
try:
    from apex.contrib.index_mul_2d import index_mul_2d
except ImportError as e:
    HAS_INDEX_MUL_2D_RELU = False
else:
    HAS_INDEX_MUL_2D_RELU = True


@unittest.skipIf(not HAS_INDEX_MUL_2D_RELU, "`apex.contrib.index_mul_2d` is not found.")
class IndexMul2dTest(unittest.TestCase):
    def setUp(self, seed=0):
        torch.manual_seed(seed)

        self.input1_size = random.randint(1, 1000)
        self.input2_size = random.randint(1, 100000)
        self.feature_size = random.randint(1, 256)

        self.input1_float = torch.randn(size=(self.input1_size, self.feature_size),).cuda()
        self.input2_float = torch.randn(size=(self.input2_size, self.feature_size),).cuda()
        self.index1 = torch.randint(low=0, high=self.input1_size, size=(self.input2_size,)).cuda()

        self.input1_float_ = self.input1_float.clone()
        self.input2_float_ = self.input2_float.clone()

        self.input1_float.requires_grad_()
        self.input1_float_.requires_grad_()
        self.input2_float.requires_grad_()
        self.input2_float_.requires_grad_()

        self.input1_half = torch.randn(size=(self.input1_size, self.feature_size),).cuda().half()
        self.input2_half = torch.randn(size=(self.input2_size, self.feature_size),).cuda().half()

        self.input1_half_ = self.input1_half.clone()
        self.input2_half_ = self.input2_half.clone()

        self.input1_half.requires_grad_()
        self.input2_half.requires_grad_()
        self.input1_half_.requires_grad_()
        self.input2_half_.requires_grad_()

    def test_index_mul_float(self):
        out = index_mul_2d(self.input1_float, self.input2_float, self.index1)
        energy = (out.float()**2).sum() / out.numel()
        force = torch.autograd.grad(
            energy,
            self.input1_float,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )[0]
        loss = (out.float()**2).sum() / out.numel() + (force.float()**2).sum()
        loss.backward()

        out_ = self.input1_float_[self.index1] * self.input2_float_
        energy_ = (out_.float()**2).sum() / out.numel()
        force_ = torch.autograd.grad(
            energy_,
            self.input1_float_,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )[0]
        loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum()
        loss.backward()

        self.assertTrue(torch.allclose(self.input1_float, self.input1_float_, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.input2_float, self.input2_float_, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.input1_float.grad, self.input1_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.input2_float.grad, self.input2_float_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))

    def test_index_mul_half(self):
        out = index_mul_2d(self.input1_half, self.input2_half, self.index1)
        energy = (out.float()**2).sum() / out.numel()
        force = torch.autograd.grad(
            energy,
            self.input1_half,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )[0]
        loss = (out.float()**2).sum() / out.numel() + (force.float()**2).sum()
        loss.backward()

        out_ = self.input1_half_[self.index1] * self.input2_half_
        energy_ = (out_.float()**2).sum() / out.numel()
        force_ = torch.autograd.grad(
            energy_,
            self.input1_half_,
            grad_outputs=torch.ones_like(energy),
            create_graph=True,
        )[0]
        loss = (out_.float()**2).sum() / out_.numel() + (force_.float()**2).sum()
        loss.backward()

        self.assertTrue(torch.allclose(self.input1_half, self.input1_half_, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.input2_half, self.input2_half_, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.input1_half.grad, self.input1_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))
        self.assertTrue(torch.allclose(self.input2_half.grad, self.input2_half_.grad, atol=1e-3, rtol=1e-3, equal_nan=True))


if __name__ == '__main__':
    unittest.main()
setup.py (modified)
...
@@ -307,6 +307,23 @@ if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv:
                                 extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                                     'nvcc':['-O3'] + version_dependent_macros}))

if "--index_mul_2d" in sys.argv:
    sys.argv.remove("--index_mul_2d")
    ext_modules.append(
        CUDAExtension(
            name='fused_index_mul_2d',
            sources=[
                'apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda.cpp',
                'apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu',
            ],
            include_dirs=[os.path.join(this_dir, 'csrc')],
            extra_compile_args={
                'cxx': ['-O3'] + version_dependent_macros,
                'nvcc': (['-O3', '--use_fast_math', '--ftz=false'] if not IS_ROCM_PYTORCH else ['-O3']) + version_dependent_macros,
            },
        )
    )

if "--deprecated_fused_adam" in sys.argv or "--cuda_ext" in sys.argv:
    from torch.utils.cpp_extension import CUDAExtension
...
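With this setup.py change, the extension only builds when the "--index_mul_2d" option is passed to the apex build (the option is removed from sys.argv before setuptools sees it). A quick way to confirm the extension is available after installation, mirroring the import guard used in the new test file:

try:
    from apex.contrib.index_mul_2d import index_mul_2d  # requires the compiled fused_index_mul_2d module
except ImportError:
    print("fused index_mul_2d extension is not built; rebuild apex with the --index_mul_2d option")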
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment