"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "1209978d97fc0a1bb04182916d6e3332b842ace3"
Commit a4eb97fb authored by Thor Johnsen

Bug fixes

parent 40a0e025
@@ -30,15 +30,23 @@ void deleter(void* ptr)
  */
 template<class T>
-at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options)
+at::Tensor blob_view(T* raw_ptr, std::vector<int64_t> shape, const at::TensorOptions& options, bool channels_last)
 {
-    std::vector<int64_t> strides(shape.size());
     size_t size = 1;
-    int idx = strides.size();
-    for (auto it = shape.rbegin(); it != shape.rend(); ++it)
-    {
-        strides[--idx] = size;
-        size *= *it;
-    }
+    std::vector<int64_t> strides(shape.size());
+    if (channels_last) {
+        assert(shape.size() == 4);
+        strides[0] = shape[1]*shape[2]*shape[3];
+        strides[1] = 1;
+        strides[2] = shape[1]*shape[3];
+        strides[3] = shape[1];
+    } else {
+        int idx = strides.size();
+        for (auto it = shape.rbegin(); it != shape.rend(); ++it)
+        {
+            strides[--idx] = size;
+            size *= *it;
+        }
+    }
     size *= sizeof(T);
     // TODO: Implement dynamic reuse of pooled peer memory.
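Aside (editor's sketch, not part of the patch): the new channels_last branch hand-computes NHWC strides for an NCHW-shaped view. They can be cross-checked against the strides PyTorch itself reports for a channels-last tensor of the same shape.

    # Editor's sketch: verify the hand-computed channels_last strides
    # (strides[0..3] from the C++ branch above) against PyTorch.
    import torch

    N, C, H, W = 2, 8, 4, 4
    expected = (C * H * W, 1, W * C, C)
    t = torch.empty((N, C, H, W), memory_format=torch.channels_last)
    assert t.stride() == expected
    print(t.stride())  # (128, 1, 32, 8)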
@@ -139,11 +147,11 @@ __device__ void strided_copy_kernel(
     }
 }
+template<bool wait, bool clear>
 __device__ void dual_signal_wait_clear(
     volatile int* signal1_flag, volatile int* wait1_flag,
     volatile int* signal2_flag, volatile int* wait2_flag,
-    const int v1, const int v2, const int v3, const int v4,
-    const bool clear
+    const int v1, const int v2, const int v3, const int v4
     )
 {
     register int r1, r2, r3, r4, r5, r6, r7, r8;
@@ -152,17 +160,20 @@ __device__ void dual_signal_wait_clear(
     if (is_main_thread) {
         asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal1_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
         asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(signal2_flag), "r"(v1), "r"(v2), "r"(v3), "r"(v4) : "memory");
-        do {
-            asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait1_flag) : "memory");
-            asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r5), "=r"(r6), "=r"(r7), "=r"(r8) : "l"(wait2_flag) : "memory");
-        } while (r1 != v1 || r5 != v1 || r2 != v2 || r6 != v2 || r3 != v3 || r7 != v3 || r4 != v4 || r8 != v4);
+        if (wait) {
+            do {
+                asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4) : "l"(wait1_flag) : "memory");
+                asm volatile("ld.volatile.global.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(r5), "=r"(r6), "=r"(r7), "=r"(r8) : "l"(wait2_flag) : "memory");
+            } while (r1 != v1 || r5 != v1 || r2 != v2 || r6 != v2 || r3 != v3 || r7 != v3 || r4 != v4 || r8 != v4);
+        }
     }
     cg::this_grid().sync();
-    // optionally clear wait flag
-    if (clear && is_main_thread) {
-        r1 = 0; r2 = 0; r3 = 0; r4 = 0;
-        asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait1_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
-        asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait2_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
+    if (clear) {
+        if (is_main_thread) {
+            r1 = 0; r2 = 0; r3 = 0; r4 = 0;
+            asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait1_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
+            asm volatile("st.volatile.global.v4.u32 [%0], {%1,%2,%3,%4};" :: "l"(wait2_flag), "r"(r1), "r"(r2), "r"(r3), "r"(r4) : "memory");
+        }
     }
 }
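Aside (editor's sketch): dual_signal_wait_clear implements a flag handshake with volatile 128-bit PTX loads and stores; the control flow is easier to see with the asm stripped away. The Python sketch below mimics the same signal, spin-wait, clear sequence on the host with two threads. All names (Flags, signal_wait_clear) are illustrative, not apex API.

    # Editor's sketch of the handshake semantics: each rank writes an agreed
    # 4-int pattern into its neighbor's signal slot, spins until its own slot
    # holds the pattern, then optionally clears it for the next exchange.
    import threading

    PATTERN = (-987751720, 840868300, -225529332, 281513358)

    class Flags:
        def __init__(self):
            self.slots = {}
            self.lock = threading.Lock()
        def store(self, rank, values):
            with self.lock:
                self.slots[rank] = tuple(values)
        def load(self, rank):
            with self.lock:
                return self.slots.get(rank, (0, 0, 0, 0))

    def signal_wait_clear(flags, me, neighbor, wait=True, clear=True):
        flags.store(neighbor, PATTERN)        # signal the neighbor
        if wait:
            while flags.load(me) != PATTERN:  # spin until the neighbor signals us
                pass
        if clear:
            flags.store(me, (0, 0, 0, 0))     # reset our wait slot

    flags = Flags()
    t0 = threading.Thread(target=signal_wait_clear, args=(flags, 0, 1))
    t1 = threading.Thread(target=signal_wait_clear, args=(flags, 1, 0))
    t0.start(); t1.start(); t0.join(); t1.join()
    print("handshake complete")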
@@ -173,12 +184,14 @@ __launch_bounds__(128, 16)
 __global__ void push_pull_halos_1d_kernel(
     // top halo,
     const T* toh, int toh_stride_C, int toh_stride_H, int toh_stride_W,  // top output halo
-    T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W,        // top tx buffer
+    T* tox, int tox_stride_C, int tox_stride_H, int tox_stride_W,        // top output tx buffer
+    T* tix, int tix_stride_C, int tix_stride_H, int tix_stride_W,        // top input tx buffer
     T* tih, int tih_stride_C, int tih_stride_H, int tih_stride_W,        // top input halo
     // btm halo
-    const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W,  // top output halo
-    T* box, int box_stride_C, int box_stride_H, int box_stride_W,        // top tx buffer
-    T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W,        // top input halo
+    const T* boh, int boh_stride_C, int boh_stride_H, int boh_stride_W,  // btm output halo
+    T* box, int box_stride_C, int box_stride_H, int box_stride_W,        // btm output tx buffer
+    T* bix, int bix_stride_C, int bix_stride_H, int bix_stride_W,        // btm input tx buffer
+    T* bih, int bih_stride_C, int bih_stride_H, int bih_stride_W,        // btm input halo
     // dimensions
     int NC, int NH, int NW,
     // signals
@@ -194,11 +207,11 @@ __global__ void push_pull_halos_1d_kernel(
     strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, boh, boh_stride_C, boh_stride_H, boh_stride_W, NC, NH, NW);
     // signal to top and btm neighbors that output halos are ready to be read
     // the choice of values for v1-v4 is arbitrary and does not matter, as long as all ranks use the same values
-    dual_signal_wait_clear(signal1_flag, wait1_flag, signal2_flag, wait2_flag, -987751720, 840868300, -225529332, 281513358, true);
+    dual_signal_wait_clear<true,true>(signal1_flag, wait1_flag, signal2_flag, wait2_flag, -987751720, 840868300, -225529332, 281513358);
     // pull top halo from transfer buffer in peer memory to input
-    strided_copy_kernel<T,is_HWC>(tox, tox_stride_C, tox_stride_H, tox_stride_W, tih, tih_stride_C, tih_stride_H, tih_stride_W, NC, NH, NW);
+    strided_copy_kernel<T,is_HWC>(tih, tih_stride_C, tih_stride_H, tih_stride_W, tix, tix_stride_C, tix_stride_H, tix_stride_W, NC, NH, NW);
     // pull btm halo from transfer buffer in peer memory to input
-    strided_copy_kernel<T,is_HWC>(box, box_stride_C, box_stride_H, box_stride_W, bih, bih_stride_C, bih_stride_H, bih_stride_W, NC, NH, NW);
+    strided_copy_kernel<T,is_HWC>(bih, bih_stride_C, bih_stride_H, bih_stride_W, bix, bix_stride_C, bix_stride_H, bix_stride_W, NC, NH, NW);
     }
 }
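Aside (editor's sketch): the kernel's per-rank dataflow is push then pull: copy this rank's output halos into its own transfer buffers, run the signal/wait handshake, then read the neighbors' transfer buffers (reached through the new tix/bix peer pointers) into this rank's input halos. The numpy toy below shows that ordering for two ranks; the ring wraparound and all names are illustrative assumptions, not apex API.

    # Editor's sketch of the push/pull halo dataflow between ranks.
    import numpy as np

    def exchange_1d(ranks):
        # phase 1: each rank pushes its top/btm output halo rows into its own
        # transfer buffers (in the real code these live in peer pool memory)
        for r in ranks:
            r["top_tx"][:] = r["x"][:1]    # top output halo = first row
            r["btm_tx"][:] = r["x"][-1:]   # btm output halo = last row
        # phase 2: after the handshake, each rank pulls from its neighbors:
        # the top input halo comes from the top neighbor's btm_tx, and so on
        n = len(ranks)
        for i, r in enumerate(ranks):
            r["top_halo"] = ranks[(i - 1) % n]["btm_tx"].copy()
            r["btm_halo"] = ranks[(i + 1) % n]["top_tx"].copy()

    ranks = [{"x": np.full((4, 8), i, dtype=np.float32),
              "top_tx": np.zeros((1, 8), np.float32),
              "btm_tx": np.zeros((1, 8), np.float32)} for i in range(2)]
    exchange_1d(ranks)
    assert (ranks[0]["btm_halo"] == 1).all() and (ranks[1]["top_halo"] == 0).all()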
@@ -246,29 +259,32 @@ std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw)
     return results;
 }
-at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape)
+at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last)
 {
-    return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA));
+    return blob_view<at::Half>((at::Half*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA), channels_last);
 }
-at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape)
+at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last)
 {
-    return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA));
+    return blob_view<float>((float*)raw, shape, torch::dtype(torch::kFloat32).device(torch::kCUDA), channels_last);
 }
-at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape)
+at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last)
 {
-    return blob_view<int>((int*)raw, shape, torch::dtype(torch::kFloat16).device(torch::kCUDA));
+    return blob_view<int>((int*)raw, shape, torch::dtype(torch::kInt32).device(torch::kCUDA), channels_last);
 }
 void push_pull_halos_1d(
+    bool diagnostics,
     bool explicit_nhwc,
     int numSM,                  // number of SMs to use
     at::Tensor top_out_halo,    // top output halo in sender device memory
     at::Tensor top_out_tx,      // top output transfer buffer in sender peer pool memory
+    at::Tensor top_inp_tx,      // top input transfer buffer in top neighbor peer pool memory
     at::Tensor top_inp_halo,    // top input halo in receiver device memory
     at::Tensor btm_out_halo,    // btm output halo in sender device memory
     at::Tensor btm_out_tx,      // btm output transfer buffer in sender peer pool memory
+    at::Tensor btm_inp_tx,      // btm input transfer buffer in btm neighbor peer pool memory
     at::Tensor btm_inp_halo,    // btm input halo in receiver device memory
     at::Tensor top_signal,      // top input signal in receiver device memory
     at::Tensor btm_signal,      // btm input signal in receiver device memory
@@ -278,9 +294,11 @@ void push_pull_halos_1d(
     // basic checks of inputs
     TORCH_CHECK(top_out_halo.is_cuda());
     TORCH_CHECK(top_out_tx.is_cuda());
+    TORCH_CHECK(top_inp_tx.is_cuda());
     TORCH_CHECK(top_inp_halo.is_cuda());
     TORCH_CHECK(btm_out_halo.is_cuda());
     TORCH_CHECK(btm_out_tx.is_cuda());
+    TORCH_CHECK(btm_inp_tx.is_cuda());
     TORCH_CHECK(btm_inp_halo.is_cuda());
     TORCH_CHECK(top_signal.is_cuda());
     TORCH_CHECK(btm_signal.is_cuda());
@@ -291,46 +309,56 @@ void push_pull_halos_1d(
     tensor_shape(top_out_halo, explicit_nhwc, toh_N, toh_C, toh_H, toh_W);
     int tox_N, tox_C, tox_H, tox_W;
     tensor_shape(top_out_tx, explicit_nhwc, tox_N, tox_C, tox_H, tox_W);
+    int tix_N, tix_C, tix_H, tix_W;
+    tensor_shape(top_inp_tx, explicit_nhwc, tix_N, tix_C, tix_H, tix_W);
     int tih_N, tih_C, tih_H, tih_W;
     tensor_shape(top_inp_halo, explicit_nhwc, tih_N, tih_C, tih_H, tih_W);
     TORCH_CHECK(
-        (toh_N == tox_N && tox_N == tih_N) &&
-        (toh_C == tox_C && tox_C == tih_C) &&
-        (toh_H == tox_H && tox_H == tih_H) &&
-        (toh_W == tox_W && tox_W == tih_W));
+        (toh_N == tox_N && tox_N == tix_N && tix_N == tih_N) &&
+        (toh_C == tox_C && tox_C == tix_C && tix_C == tih_C) &&
+        (toh_H == tox_H && tox_H == tix_H && tix_H == tih_H) &&
+        (toh_W == tox_W && tox_W == tix_W && tix_W == tih_W));
     int boh_N, boh_C, boh_H, boh_W;
     tensor_shape(btm_out_halo, explicit_nhwc, boh_N, boh_C, boh_H, boh_W);
     int box_N, box_C, box_H, box_W;
     tensor_shape(btm_out_tx, explicit_nhwc, box_N, box_C, box_H, box_W);
+    int bix_N, bix_C, bix_H, bix_W;
+    tensor_shape(btm_inp_tx, explicit_nhwc, bix_N, bix_C, bix_H, bix_W);
     int bih_N, bih_C, bih_H, bih_W;
     tensor_shape(btm_inp_halo, explicit_nhwc, bih_N, bih_C, bih_H, bih_W);
     TORCH_CHECK(
-        (boh_N == box_N && box_N == bih_N) &&
-        (boh_C == box_C && box_C == bih_C) &&
-        (boh_H == box_H && box_H == bih_H) &&
-        (boh_W == box_W && box_W == bih_W));
+        (boh_N == box_N && box_N == bix_N && bix_N == bih_N) &&
+        (boh_C == box_C && box_C == bix_C && bix_C == bih_C) &&
+        (boh_H == box_H && box_H == bix_H && bix_H == bih_H) &&
+        (boh_W == box_W && box_W == bix_W && bix_W == bih_W));
     TORCH_CHECK(
         (toh_N == boh_N) &&
         (toh_C == boh_C) &&
         (toh_H == boh_H) &&
         (toh_W == boh_W));
     int NC=toh_C, NH=toh_H, NW=toh_W;
+    if (diagnostics) printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
     int toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W;
     tensor_strides(top_out_halo, explicit_nhwc, toh_stride_N, toh_stride_C, toh_stride_H, toh_stride_W);
     int tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W;
     tensor_strides(top_out_tx, explicit_nhwc, tox_stride_N, tox_stride_C, tox_stride_H, tox_stride_W);
+    int tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W;
+    tensor_strides(top_inp_tx, explicit_nhwc, tix_stride_N, tix_stride_C, tix_stride_H, tix_stride_W);
     int tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W;
     tensor_strides(top_inp_halo, explicit_nhwc, tih_stride_N, tih_stride_C, tih_stride_H, tih_stride_W);
     int boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W;
     tensor_strides(btm_out_halo, explicit_nhwc, boh_stride_N, boh_stride_C, boh_stride_H, boh_stride_W);
     int box_stride_N, box_stride_C, box_stride_H, box_stride_W;
     tensor_strides(btm_out_tx, explicit_nhwc, box_stride_N, box_stride_C, box_stride_H, box_stride_W);
+    int bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W;
+    tensor_strides(btm_inp_tx, explicit_nhwc, bix_stride_N, bix_stride_C, bix_stride_H, bix_stride_W);
     int bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W;
     tensor_strides(btm_inp_halo, explicit_nhwc, bih_stride_N, bih_stride_C, bih_stride_H, bih_stride_W);
     // determine if nhwc
     auto is_nhwc = (toh_stride_C == 1) ? true : false;
+    if (diagnostics) printf("is_nhwc = %s\n",is_nhwc?"true":"false");
     // figure out launch parameters
     int device;
@@ -342,35 +370,59 @@ void push_pull_halos_1d(
     const int numThreads = 128;
     dim3 block(numThreads,1,1);
     AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, top_out_halo.scalar_type(), "push_pull_halos_1d_kernel", [&]{
+        if (diagnostics) printf("size(scalar_t) = %d\n",sizeof(scalar_t));
         scalar_t* toh_p = top_out_halo.data_ptr<scalar_t>();
         scalar_t* tox_p = top_out_tx.data_ptr<scalar_t>();
+        scalar_t* tix_p = top_inp_tx.data_ptr<scalar_t>();
         scalar_t* tih_p = top_inp_halo.data_ptr<scalar_t>();
         scalar_t* boh_p = btm_out_halo.data_ptr<scalar_t>();
         scalar_t* box_p = btm_out_tx.data_ptr<scalar_t>();
+        scalar_t* bix_p = btm_inp_tx.data_ptr<scalar_t>();
         scalar_t* bih_p = btm_inp_halo.data_ptr<scalar_t>();
-        int* top_signal_p = top_signal.data_ptr<int>();
-        int* btm_signal_p = btm_signal.data_ptr<int>() + 4;
+        if (diagnostics) printf("waypoint1\n");
+        int* top_signal_p = top_signal.data_ptr<int>() + 4;
+        int* btm_signal_p = btm_signal.data_ptr<int>();
         int* top_wait_p = waits.data_ptr<int>();
         int* btm_wait_p = waits.data_ptr<int>() + 4;
+        if (diagnostics) printf("waypoint2\n");
         // do int4 vector loads if channel count permits
         int elem_size_in_bytes = toh_C * sizeof(scalar_t);
         int elem_size_in_int4 = (elem_size_in_bytes / 16);
+        if (diagnostics) printf("elem_size_in_bytes = %d, elem_size_in_int4 = %d\n",elem_size_in_bytes,elem_size_in_int4);
         if (is_nhwc && elem_size_in_int4*16 == elem_size_in_bytes) {
             // can do int4 transfers
-            int divisor = elem_size_in_bytes / elem_size_in_int4;
+            int divisor = toh_C / elem_size_in_int4;
+            if (diagnostics) printf("CAN DO INT4 :: divisor = %d\n",divisor);
             toh_stride_N /= divisor; toh_stride_H /= divisor; toh_stride_W /= divisor;
             tox_stride_N /= divisor; tox_stride_H /= divisor; tox_stride_W /= divisor;
+            tix_stride_N /= divisor; tix_stride_H /= divisor; tix_stride_W /= divisor;
             tih_stride_N /= divisor; tih_stride_H /= divisor; tih_stride_W /= divisor;
             boh_stride_N /= divisor; boh_stride_H /= divisor; boh_stride_W /= divisor;
             box_stride_N /= divisor; box_stride_H /= divisor; box_stride_W /= divisor;
+            bix_stride_N /= divisor; bix_stride_H /= divisor; bix_stride_W /= divisor;
             bih_stride_N /= divisor; bih_stride_H /= divisor; bih_stride_W /= divisor;
+            NC /= divisor;
+            if (diagnostics) {
+                printf("divisor=%d\n",divisor);
+                printf("toh_stride :: N=%d, C=%d, H=%d, W=%d\n",toh_stride_N,toh_stride_C,toh_stride_H,toh_stride_W);
+                printf("tox_stride :: N=%d, C=%d, H=%d, W=%d\n",tox_stride_N,tox_stride_C,tox_stride_H,tox_stride_W);
+                printf("tix_stride :: N=%d, C=%d, H=%d, W=%d\n",tix_stride_N,tix_stride_C,tix_stride_H,tix_stride_W);
+                printf("tih_stride :: N=%d, C=%d, H=%d, W=%d\n",tih_stride_N,tih_stride_C,tih_stride_H,tih_stride_W);
+                printf("boh_stride :: N=%d, C=%d, H=%d, W=%d\n",boh_stride_N,boh_stride_C,boh_stride_H,boh_stride_W);
+                printf("box_stride :: N=%d, C=%d, H=%d, W=%d\n",box_stride_N,box_stride_C,box_stride_H,box_stride_W);
+                printf("bix_stride :: N=%d, C=%d, H=%d, W=%d\n",bix_stride_N,bix_stride_C,bix_stride_H,bix_stride_W);
+                printf("bih_stride :: N=%d, C=%d, H=%d, W=%d\n",bih_stride_N,bih_stride_C,bih_stride_H,bih_stride_W);
+                printf("NC=%d, NH=%d, NW=%d\n",NC,NH,NW);
+            }
             void *kernelArgs[] = {
                 (int4**)&toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                 (int4**)&tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
+                (int4**)&tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
                 (int4**)&tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                 (int4**)&boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                 (int4**)&box_p, &box_stride_C, &box_stride_H, &box_stride_W,
+                (int4**)&bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
                 (int4**)&bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                 &NC, &NH, &NW,
                 &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
@@ -381,12 +433,15 @@ void push_pull_halos_1d(
             cudaLaunchCooperativeKernel((void*)push_pull_halos_1d_kernel<int4,true>, grid, block, kernelArgs, 0, current_stream);
         } else {
             // cannot do int4 transfers
+            if (diagnostics) printf("CAN NOT DO INT4\n");
             void *kernelArgs[] = {
                 &toh_p, &toh_stride_C, &toh_stride_H, &toh_stride_W,
                 &tox_p, &tox_stride_C, &tox_stride_H, &tox_stride_W,
+                &tix_p, &tix_stride_C, &tix_stride_H, &tix_stride_W,
                 &tih_p, &tih_stride_C, &tih_stride_H, &tih_stride_W,
                 &boh_p, &boh_stride_C, &boh_stride_H, &boh_stride_W,
                 &box_p, &box_stride_C, &box_stride_H, &box_stride_W,
+                &bix_p, &bix_stride_C, &bix_stride_H, &bix_stride_W,
                 &bih_p, &bih_stride_C, &bih_stride_H, &bih_stride_W,
                 &NC, &NH, &NW,
                 &top_signal_p, &btm_signal_p, &top_wait_p, &btm_wait_p
...
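Aside (editor's sketch): the int4 path above triggers only for NHWC data whose per-pixel channel footprint, toh_C * sizeof(scalar_t), is a multiple of 16 bytes; the fixed divisor (toh_C / elem_size_in_int4, i.e. scalars per 16-byte word) then rescales NC and every non-C stride. The arithmetic in plain Python, with illustrative names:

    # Editor's sketch of the vectorization feasibility test and divisor.
    def int4_plan(C, elem_bytes):
        row_bytes = C * elem_bytes      # contiguous channel bytes per (h, w) in NHWC
        n_int4 = row_bytes // 16        # 16-byte int4 words per (h, w)
        if n_int4 == 0 or n_int4 * 16 != row_bytes:
            return None                 # not divisible: fall back to scalar copies
        return C // n_int4              # scalars per int4 word (= 16 // elem_bytes)

    assert int4_plan(64, 2) == 8        # fp16, C=64: 8 halves per int4 word
    assert int4_plan(3, 4) is None      # fp32, C=3: 12 bytes, cannot vectorize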
@@ -24,17 +24,20 @@ namespace apex { namespace peer_memory {
 void free_raw(int64_t raw);
 at::Tensor get_raw_ipc_address(int64_t raw);
 std::vector<int64_t> get_raw_peers(at::Tensor ipc_addresses, int peer_rank, int64_t raw);
-at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape);
-at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape);
-at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape);
+at::Tensor blob_view_half(int64_t raw, std::vector<int64_t> shape, bool channels_last);
+at::Tensor blob_view_float(int64_t raw, std::vector<int64_t> shape, bool channels_last);
+at::Tensor blob_view_int(int64_t raw, std::vector<int64_t> shape, bool channels_last);
 void push_pull_halos_1d(
+    bool diagnostics,
     bool explicit_nhwc,
     int numSM,                  // number of SMs to use
     at::Tensor top_out_halo,    // top output halo in sender device memory
     at::Tensor top_out_tx,      // top output transfer buffer in sender peer pool memory
+    at::Tensor top_inp_tx,      // top input transfer buffer in top neighbor peer pool memory
     at::Tensor top_inp_halo,    // top input halo in receiver device memory
     at::Tensor btm_out_halo,    // btm output halo in sender device memory
     at::Tensor btm_out_tx,      // btm output transfer buffer in sender peer pool memory
+    at::Tensor btm_inp_tx,      // btm input transfer buffer in btm neighbor peer pool memory
     at::Tensor btm_inp_halo,    // btm input halo in receiver device memory
     at::Tensor top_signal,      // top input signal in receiver device memory
     at::Tensor btm_signal,      // btm input signal in receiver device memory
...
 import torch
 import numpy as np
-import peer_memory
+import peer_memory as pm

 class PeerMemoryPool(object):

-    def __init__(self, rank, world_size, peer_group_size, size):
+    def __init__(self, rank, world_size, peer_group_size, static_size, dynamic_size):
         self.peer_group = rank // peer_group_size
         self.peer_rank = rank % peer_group_size
         self.peer_group_size = peer_group_size
         self.alignment = 256
-        self.size = size
+        self.static_size = ((static_size + self.alignment - 1) // self.alignment) * self.alignment
+        self.dynamic_size = ((dynamic_size + self.alignment - 1) // self.alignment) * self.alignment
         # allocate giant pool of device memory
-        self.raw = allocate_raw(size)
+        self.raw = pm.allocate_raw(self.static_size+self.dynamic_size)
         # exchange peer pointers with nccl
-        raw_ipc = get_raw_ipc_address(self.raw).cuda()
+        raw_ipc = pm.get_raw_ipc_address(self.raw).cuda()
         peer_raw_ipcs = [torch.empty_like(raw_ipc) for _ in range(world_size)]
         torch.distributed.all_gather(peer_raw_ipcs, raw_ipc)
         peer_raw_ipcs = torch.stack(peer_raw_ipcs).cpu()
-        self.peer_raw = get_raw_peers(peer_raw_ipcs, self.peer_rank, self.raw)
-        self.current = 0
+        self.peer_raw = pm.get_raw_peers(peer_raw_ipcs, self.peer_rank, self.raw)
+        self.static_offset = 0
+        self.dynamic_offset = 0

     def __del__(self):
-        free_raw(self.raw)
+        pm.free_raw(self.raw)

     def reset(self):
-        self.current = 0
+        self.dynamic_offset = 0

-    def allocate_peer_tensors(self, shape, dtype):
+    def allocate_peer_tensors(self, shape, dtype, channels_last, dynamic):
         nels = np.prod(shape)
         if dtype == torch.float16:
             elem_size = 2
-            start = ((self.current + self.alignment - 1) // self.alignment) * self.alignment
-            self.current = start + nels * elem_size
-            assert(self.current < self.size), "Peer memory pool exhausted"
-            return [blob_view_half(pr + start, shape) for pr in self.peer_raw]
-        elif dtype == torch.float32:
+            if dynamic:
+                start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.dynamic_offset = start + nels * elem_size
+                assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
+                return [pm.blob_view_half(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
+            else:
+                start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.static_offset = start + nels * elem_size
+                assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
+                return [pm.blob_view_half(pr + start, shape, channels_last) for pr in self.peer_raw]
+        if dtype == torch.float32:
             elem_size = 4
-            start = ((self.current + self.alignment - 1) // self.alignment) * self.alignment
-            self.current = start + nels * elem_size
-            assert(self.current < self.size), "Peer memory pool exhausted"
-            return [blob_view_float(pr + start, shape) for pr in self.peer_raw]
-        elif dtype == torch.int32:
+            if dynamic:
+                start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.dynamic_offset = start + nels * elem_size
+                assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
+                return [pm.blob_view_float(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
+            else:
+                start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.static_offset = start + nels * elem_size
+                assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
+                return [pm.blob_view_float(pr + start, shape, channels_last) for pr in self.peer_raw]
+        if dtype == torch.int32:
             elem_size = 4
-            start = ((self.current + self.alignment - 1) // self.alignment) * self.alignment
-            self.current = start + nels * elem_size
-            assert(self.current < self.size), "Peer memory pool exhausted"
-            return [blob_view_int(pr + start, shape) for pr in self.peer_raw]
+            if dynamic:
+                start = ((self.dynamic_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.dynamic_offset = start + nels * elem_size
+                assert(self.dynamic_offset < self.dynamic_size), "Dynamic peer memory pool exhausted"
+                return [pm.blob_view_int(pr + self.static_size + start, shape, channels_last) for pr in self.peer_raw]
+            else:
+                start = ((self.static_offset + self.alignment - 1) // self.alignment) * self.alignment
+                self.static_offset = start + nels * elem_size
+                assert(self.static_offset < self.static_size), "Static peer memory pool exhausted"
+                return [pm.blob_view_int(pr + start, shape, channels_last) for pr in self.peer_raw]
         else:
-            assert(False), "Unknown dtype : %s" % (str(dtype))
+            assert(False), "dtype %s not supported" % (str(dtype))