"examples/pytorch/vscode:/vscode.git/clone" did not exist on "a90296aa9c807acc391e8ef0f4aaa2052de24923"
Unverified commit 2bbca12a, authored by nv-dlasalle and committed by GitHub

[bugfix] Fix assertions in /src/runtime/workspace.h and expand unit tests for sparse optimizer (#5299)

* Fix assertions for size 0 workspaces (see the sketch below, just before the diff)

* Expand unit test to cover case of communication

* Fixes

* Format

* Fix c++ formatting
parent a5e31391
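
The sketch below is a minimal, self-contained illustration of the failure mode this commit addresses. It uses stand-in allocator functions rather than DGL's real DeviceAPI, and all names in it are hypothetical. A device allocator may legitimately return a null pointer for a zero-byte request, so the old assert(*this) inside get() and free() could fire on a perfectly valid zero-sized workspace. Recording the requested size in size_ and relaxing the checks to assert(size_ == 0 || *this) keeps the null-pointer guard for non-empty workspaces while permitting the empty case. The expanded test below adds a zero_comm=False mode in which each rank samples indices from the whole embedding table, so gradient rows have to be exchanged between ranks, and a rank's share of that exchange can be small or empty.

// Minimal sketch (not the real DGL code): a stand-in allocator that, like some
// device allocators, returns nullptr for a zero-byte request.
#include <cassert>
#include <cstddef>
#include <cstdlib>

void* AllocWorkspace(size_t nbytes) {              // stand-in for DeviceAPI::AllocWorkspace
  return nbytes == 0 ? nullptr : std::malloc(nbytes);
}
void FreeWorkspace(void* ptr) { std::free(ptr); }  // stand-in for DeviceAPI::FreeWorkspace

template <typename T>
class Workspace {
 public:
  explicit Workspace(const size_t size)
      : size_(size * sizeof(T)),                        // size recorded in bytes, as in the patch
        ptr_(static_cast<T*>(AllocWorkspace(size_))) {}
  ~Workspace() {
    if (*this) free();
  }
  operator bool() const { return ptr_ != nullptr; }
  T* get() {
    // The old check was assert(*this), which fires for a valid zero-sized workspace.
    assert(size_ == 0 || *this);
    return ptr_;
  }
  void free() {
    assert(size_ == 0 || *this);
    FreeWorkspace(ptr_);
    ptr_ = nullptr;
  }

 private:
  size_t size_;
  T* ptr_;
};

int main() {
  Workspace<int> empty(0);    // zero elements: ptr_ may be nullptr
  (void)empty.get();          // no longer trips the assertion
  Workspace<int> scratch(16);
  scratch.get()[0] = 42;      // non-empty workspaces behave as before
  return 0;
}

In the real header the constructor also takes the DeviceAPI pointer and a DGLContext, and a Workspace<void> specialization takes its size directly in bytes; the sketch drops those details so it can compile on its own.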
src/runtime/workspace.h:

@@ -20,8 +20,8 @@ class Workspace {
   Workspace(DeviceAPI* device, DGLContext ctx, const size_t size)
       : device_(device),
         ctx_(ctx),
-        ptr_(static_cast<T*>(device_->AllocWorkspace(ctx_, sizeof(T) * size))) {
-  }
+        size_(size * sizeof(T)),
+        ptr_(static_cast<T*>(device_->AllocWorkspace(ctx_, size_))) {}
   ~Workspace() {
     if (*this) {
@@ -32,17 +32,17 @@ class Workspace {
   operator bool() const { return ptr_ != nullptr; }

   T* get() {
-    assert(*this);
+    assert(size_ == 0 || *this);
     return ptr_;
   }

   T const* get() const {
-    assert(*this);
+    assert(size_ == 0 || *this);
     return ptr_;
   }

   void free() {
-    assert(*this);
+    assert(size_ == 0 || *this);
     device_->FreeWorkspace(ctx_, ptr_);
     ptr_ = nullptr;
   }
@@ -50,6 +50,7 @@ class Workspace {
  private:
   DeviceAPI* device_;
   DGLContext ctx_;
+  size_t size_;
   T* ptr_;
 };
@@ -59,7 +60,8 @@ class Workspace<void> {
   Workspace(DeviceAPI* device, DGLContext ctx, const size_t size)
       : device_(device),
         ctx_(ctx),
-        ptr_(static_cast<void*>(device_->AllocWorkspace(ctx_, size))) {}
+        size_(size),
+        ptr_(static_cast<void*>(device_->AllocWorkspace(ctx_, size_))) {}
   ~Workspace() {
     if (*this) {
@@ -70,17 +72,17 @@ class Workspace<void> {
   operator bool() const { return ptr_ != nullptr; }

   void* get() {
-    assert(*this);
+    assert(size_ == 0 || *this);
     return ptr_;
   }

   void const* get() const {
-    assert(*this);
+    assert(size_ == 0 || *this);
     return ptr_;
   }

   void free() {
-    assert(*this);
+    assert(size_ == 0 || *this);
     device_->FreeWorkspace(ctx_, ptr_);
     ptr_ = nullptr;
   }
@@ -88,6 +90,7 @@ class Workspace<void> {
  private:
   DeviceAPI* device_;
   DGLContext ctx_;
+  size_t size_;
   void* ptr_;
 };
Sparse optimizer unit tests:

@@ -186,6 +186,7 @@ def start_sparse_adam_worker(
     backend="gloo",
     num_embs=128,
     emb_dim=10,
+    zero_comm=True,
 ):
     print("start sparse worker for adam {}".format(rank))
     dist_init_method = "tcp://{master_ip}:{master_port}".format(
@@ -218,10 +219,13 @@ def start_sparse_adam_worker(
     else:
         dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)

-    start = (num_embs // world_size) * rank
-    end = (num_embs // world_size) * (rank + 1)
     th.manual_seed(rank)
-    idx = th.randint(start, end, size=(4,)).to(tensor_dev)
+    if zero_comm:
+        start = (num_embs // world_size) * rank
+        end = (num_embs // world_size) * (rank + 1)
+        idx = th.randint(start, end, size=(4,)).to(tensor_dev)
+    else:
+        idx = th.randint(0, num_embs, size=(4,)).to(tensor_dev)
     dgl_value = dgl_emb(idx, device)
     labels = th.ones((4,)).long().to(device)
     dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
@@ -240,7 +244,13 @@ def start_sparse_adam_worker(
 def start_torch_adam_worker(
-    rank, world_size, weight, has_zero_grad=False, num_embs=128, emb_dim=10
+    rank,
+    world_size,
+    weight,
+    has_zero_grad=False,
+    num_embs=128,
+    emb_dim=10,
+    zero_comm=True,
 ):
     print("start sparse worker for adam {}".format(rank))
     dist_init_method = "tcp://{master_ip}:{master_port}".format(
@@ -275,10 +285,13 @@ def start_torch_adam_worker(
         list(torch_emb.module.parameters()), lr=0.01
     )

-    start = (num_embs // world_size) * rank
-    end = (num_embs // world_size) * (rank + 1)
     th.manual_seed(rank)
-    idx = th.randint(start, end, size=(4,))
+    if zero_comm:
+        start = (num_embs // world_size) * rank
+        end = (num_embs // world_size) * (rank + 1)
+        idx = th.randint(start, end, size=(4,))
+    else:
+        idx = th.randint(0, num_embs, size=(4,))
     labels = th.ones((4,)).long()
     torch_value = torch_emb(idx)
     torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
@@ -340,7 +353,8 @@ def test_multiprocess_cpu_sparse_adam(num_workers):
 @unittest.skipIf(F.ctx().type == "cpu", reason="gpu only test")
 @pytest.mark.parametrize("num_workers", [2, 4, 8])
 @pytest.mark.parametrize("backend", ["nccl", "gloo"])
-def test_multiprocess_sparse_adam(num_workers, backend):
+@pytest.mark.parametrize("zero_comm", [True, False])
+def test_multiprocess_sparse_adam(num_workers, backend, zero_comm):
     if F.ctx().type == "cuda" and th.cuda.device_count() < num_workers:
         pytest.skip("Not enough GPUs to run test.")
@@ -364,6 +378,9 @@ def test_multiprocess_sparse_adam(num_workers, backend):
                 th.device("cpu"),
                 True,
                 backend,
+                num_embs,
+                emb_dim,
+                zero_comm,
             ),
         )
         p.start()
@@ -376,7 +393,15 @@ def test_multiprocess_sparse_adam(num_workers, backend):
     for i in range(num_workers):
         p = ctx.Process(
             target=start_torch_adam_worker,
-            args=(i, num_workers, torch_weight, False),
+            args=(
+                i,
+                num_workers,
+                torch_weight,
+                False,
+                num_embs,
+                emb_dim,
+                zero_comm,
+            ),
         )
         p.start()
         worker_list.append(p)