"docs/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "16409ff8523df3bb69258c1773f3221c0ba906b3"
Unverified Commit 061c2a36 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub
Browse files

upd (#2352)

parent f8ebcd7f
...@@ -41,7 +41,7 @@ __global__ void GESpMMKernel( ...@@ -41,7 +41,7 @@ __global__ void GESpMMKernel(
DType accum_0 = 0., DType accum_0 = 0.,
accum_1 = 0.; accum_1 = 0.;
if (blockIdx.y != gridDim.y - 1) { if (blockIdx.y != gridDim.y - 1) { // fid + 32 < feat_len
for (Idx left = low; left < high; left += 32) { for (Idx left = low; left < high; left += 32) {
if (left + 32 <= high) { if (left + 32 <= high) {
#pragma unroll #pragma unroll
...@@ -76,39 +76,40 @@ __global__ void GESpMMKernel( ...@@ -76,39 +76,40 @@ __global__ void GESpMMKernel(
out[feat_len * rid + fid + 32] = accum_1; out[feat_len * rid + fid + 32] = accum_1;
} }
} else { } else {
bool right_inbound = fid + 32 < feat_len; const Idx fid_0 = fid < feat_len ? fid : 0,
fid_1 = fid + 32 < feat_len ? fid + 32 : 0;
for (int left = low; left < high; left += 32) { for (int left = low; left < high; left += 32) {
if (left + 32 <= high) { if (left + 32 <= high) {
#pragma unroll #pragma unroll
for (int i = 0; i < 32; ++i) { for (int i = 0; i < 32; ++i) {
const Idx eid = left + i; const Idx eid = left + i;
const Idx cid = __ldg(indices + eid); const Idx cid = __ldg(indices + eid);
const Idx offset = feat_len * cid + fid; const Idx offset = feat_len * cid;
if (BinaryOp::use_rhs) { if (BinaryOp::use_rhs) {
accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid); accum_0 += BinaryOp::Call(ufeat + offset + fid_0, efeat + eid);
accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid); accum_1 += BinaryOp::Call(ufeat + offset + fid_1, efeat + eid);
} else { } else {
accum_0 += ufeat[offset]; accum_0 += ufeat[offset + fid_0];
accum_1 += ufeat[offset + 32]; accum_1 += ufeat[offset + fid_1];
} }
} }
} else { } else {
for (int i = 0; i + left < high; ++i) { for (int i = 0; i + left < high; ++i) {
const Idx eid = left + i; const Idx eid = left + i;
const Idx cid = __ldg(indices + eid); const Idx cid = __ldg(indices + eid);
const Idx offset = feat_len * cid + fid; const Idx offset = feat_len * cid;
if (BinaryOp::use_rhs) { if (BinaryOp::use_rhs) {
accum_0 += BinaryOp::Call(ufeat + offset, efeat + eid); accum_0 += BinaryOp::Call(ufeat + offset + fid_0, efeat + eid);
accum_1 += BinaryOp::Call(ufeat + offset + 32, efeat + eid); accum_1 += BinaryOp::Call(ufeat + offset + fid_1, efeat + eid);
} else { } else {
accum_0 += ufeat[offset]; accum_0 += ufeat[offset + fid_0];
accum_1 += ufeat[offset + 32]; accum_1 += ufeat[offset + fid_1];
} }
} }
} }
out[feat_len * rid + fid] = accum_0; out[feat_len * rid + fid] = accum_0;
if (right_inbound) if (fid + 32 < feat_len)
out[feat_len * rid + fid + 32] = accum_1; out[feat_len * rid + fid + 32] = accum_1;
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment