Unverified commit 2bbca12a authored by nv-dlasalle, committed by GitHub

[bugfix] Fix assertions in /src/runtime/workspace.h and expand unit tests for sparse optimizer (#5299)

* Fix assertions for size 0 workspaces

* Expand unit test to cover case of communication

* Fixes

* Format

* Fix c++ formatting
parent a5e31391
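
For orientation: the first group of hunks below touches src/runtime/workspace.h, the second expands the sparse Adam unit tests. A minimal, self-contained sketch of the assertion problem being fixed follows (a toy stand-in using malloc rather than DGL's DeviceAPI; the name ToyWorkspace is invented for illustration): a zero-sized workspace may legitimately hold a null pointer, so the old assert(*this) in get() and free() could fire even though nothing was ever read or written. Tracking the requested byte count in size_ and asserting size_ == 0 || *this lets the empty case through while still catching misuse of non-empty workspaces.

    // toy_workspace.cc -- hedged sketch only; the real Workspace in
    // src/runtime/workspace.h delegates to DeviceAPI::AllocWorkspace /
    // FreeWorkspace instead of malloc/free.
    #include <cassert>
    #include <cstddef>
    #include <cstdlib>

    template <typename T>
    class ToyWorkspace {
     public:
      explicit ToyWorkspace(size_t size)
          : size_(size * sizeof(T)),
            // A zero-byte request may yield a null pointer.
            ptr_(static_cast<T*>(size_ ? std::malloc(size_) : nullptr)) {}

      ~ToyWorkspace() {
        if (*this) free();
      }

      operator bool() const { return ptr_ != nullptr; }

      T* get() {
        // Old form: assert(*this) -- aborts for a valid but empty workspace.
        // New form: an empty workspace is allowed to hold nullptr.
        assert(size_ == 0 || *this);
        return ptr_;
      }

      void free() {
        assert(size_ == 0 || *this);
        std::free(ptr_);
        ptr_ = nullptr;
      }

     private:
      size_t size_;
      T* ptr_;
    };

    int main() {
      ToyWorkspace<int> empty(0);    // size-0 workspace: get() must not abort
      (void)empty.get();
      ToyWorkspace<int> filled(16);  // non-empty path is still checked
      filled.get()[0] = 42;
      return 0;
    }

The actual fix applies the same size_ == 0 || *this relaxation to both the Workspace<T> template and its Workspace<void> specialization, and the expanded tests exercise the sparse Adam optimizer both with partition-local indices and with indices that cross rank boundaries (zero_comm=False).
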
@@ -20,8 +20,8 @@ class Workspace {
  Workspace(DeviceAPI* device, DGLContext ctx, const size_t size)
      : device_(device),
        ctx_(ctx),
-        ptr_(static_cast<T*>(device_->AllocWorkspace(ctx_, sizeof(T) * size))) {
-  }
+        size_(size * sizeof(T)),
+        ptr_(static_cast<T*>(device_->AllocWorkspace(ctx_, size_))) {}
  ~Workspace() {
    if (*this) {
@@ -32,17 +32,17 @@ class Workspace {
  operator bool() const { return ptr_ != nullptr; }
  T* get() {
-    assert(*this);
+    assert(size_ == 0 || *this);
    return ptr_;
  }
  T const* get() const {
-    assert(*this);
+    assert(size_ == 0 || *this);
    return ptr_;
  }
  void free() {
-    assert(*this);
+    assert(size_ == 0 || *this);
    device_->FreeWorkspace(ctx_, ptr_);
    ptr_ = nullptr;
  }
@@ -50,6 +50,7 @@ class Workspace {
 private:
  DeviceAPI* device_;
  DGLContext ctx_;
+  size_t size_;
  T* ptr_;
};
@@ -59,7 +60,8 @@ class Workspace<void> {
  Workspace(DeviceAPI* device, DGLContext ctx, const size_t size)
      : device_(device),
        ctx_(ctx),
-        ptr_(static_cast<void*>(device_->AllocWorkspace(ctx_, size))) {}
+        size_(size),
+        ptr_(static_cast<void*>(device_->AllocWorkspace(ctx_, size_))) {}
  ~Workspace() {
    if (*this) {
@@ -70,17 +72,17 @@ class Workspace<void> {
  operator bool() const { return ptr_ != nullptr; }
  void* get() {
-    assert(*this);
+    assert(size_ == 0 || *this);
    return ptr_;
  }
  void const* get() const {
-    assert(*this);
+    assert(size_ == 0 || *this);
    return ptr_;
  }
  void free() {
-    assert(*this);
+    assert(size_ == 0 || *this);
    device_->FreeWorkspace(ctx_, ptr_);
    ptr_ = nullptr;
  }
@@ -88,6 +90,7 @@ class Workspace<void> {
 private:
  DeviceAPI* device_;
  DGLContext ctx_;
+  size_t size_;
  void* ptr_;
};
@@ -186,6 +186,7 @@ def start_sparse_adam_worker(
    backend="gloo",
    num_embs=128,
    emb_dim=10,
+    zero_comm=True,
):
    print("start sparse worker for adam {}".format(rank))
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
@@ -218,10 +219,13 @@ def start_sparse_adam_worker(
    else:
        dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)
-    start = (num_embs // world_size) * rank
-    end = (num_embs // world_size) * (rank + 1)
    th.manual_seed(rank)
-    idx = th.randint(start, end, size=(4,)).to(tensor_dev)
+    if zero_comm:
+        start = (num_embs // world_size) * rank
+        end = (num_embs // world_size) * (rank + 1)
+        idx = th.randint(start, end, size=(4,)).to(tensor_dev)
+    else:
+        idx = th.randint(0, num_embs, size=(4,)).to(tensor_dev)
    dgl_value = dgl_emb(idx, device)
    labels = th.ones((4,)).long().to(device)
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
@@ -240,7 +244,13 @@ def start_sparse_adam_worker(
def start_torch_adam_worker(
-    rank, world_size, weight, has_zero_grad=False, num_embs=128, emb_dim=10
+    rank,
+    world_size,
+    weight,
+    has_zero_grad=False,
+    num_embs=128,
+    emb_dim=10,
+    zero_comm=True,
):
    print("start sparse worker for adam {}".format(rank))
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
@@ -275,10 +285,13 @@ def start_torch_adam_worker(
        list(torch_emb.module.parameters()), lr=0.01
    )
-    start = (num_embs // world_size) * rank
-    end = (num_embs // world_size) * (rank + 1)
    th.manual_seed(rank)
-    idx = th.randint(start, end, size=(4,))
+    if zero_comm:
+        start = (num_embs // world_size) * rank
+        end = (num_embs // world_size) * (rank + 1)
+        idx = th.randint(start, end, size=(4,))
+    else:
+        idx = th.randint(0, num_embs, size=(4,))
    labels = th.ones((4,)).long()
    torch_value = torch_emb(idx)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
@@ -340,7 +353,8 @@ def test_multiprocess_cpu_sparse_adam(num_workers):
@unittest.skipIf(F.ctx().type == "cpu", reason="gpu only test")
@pytest.mark.parametrize("num_workers", [2, 4, 8])
@pytest.mark.parametrize("backend", ["nccl", "gloo"])
-def test_multiprocess_sparse_adam(num_workers, backend):
+@pytest.mark.parametrize("zero_comm", [True, False])
+def test_multiprocess_sparse_adam(num_workers, backend, zero_comm):
    if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
        pytest.skip("Not enough GPUs to run test.")
@@ -364,6 +378,9 @@ def test_multiprocess_sparse_adam(num_workers, backend):
                th.device("cpu"),
                True,
                backend,
+                num_embs,
+                emb_dim,
+                zero_comm,
            ),
        )
        p.start()
@@ -376,7 +393,15 @@ def test_multiprocess_sparse_adam(num_workers, backend):
    for i in range(num_workers):
        p = ctx.Process(
            target=start_torch_adam_worker,
-            args=(i, num_workers, torch_weight, False),
+            args=(
+                i,
+                num_workers,
+                torch_weight,
+                False,
+                num_embs,
+                emb_dim,
+                zero_comm,
+            ),
        )
        p.start()
        worker_list.append(p)