Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
41276b6c
Commit
41276b6c
authored
Oct 03, 2022
by
Vijay Korthikanti
Browse files
Merge branch 'main' into nmt-main
parents
a44360ed
fc7f4f03
Changes
135
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
29 additions
and
289 deletions
+29
-289
tasks/main.py
tasks/main.py
+1
-14
tasks/msdp/evaluate.py
tasks/msdp/evaluate.py
+1
-14
tasks/msdp/main.py
tasks/msdp/main.py
+1
-14
tasks/msdp/preprocessing.py
tasks/msdp/preprocessing.py
+1
-14
tasks/msdp/prompt.py
tasks/msdp/prompt.py
+1
-14
tasks/orqa/evaluate_orqa.py
tasks/orqa/evaluate_orqa.py
+1
-14
tasks/orqa/evaluate_utils.py
tasks/orqa/evaluate_utils.py
+1
-14
tasks/orqa/supervised/data.py
tasks/orqa/supervised/data.py
+1
-14
tasks/orqa/supervised/eval_utils.py
tasks/orqa/supervised/eval_utils.py
+1
-14
tasks/orqa/supervised/finetune.py
tasks/orqa/supervised/finetune.py
+2
-15
tasks/orqa/unsupervised/nq.py
tasks/orqa/unsupervised/nq.py
+1
-14
tasks/race/finetune.py
tasks/race/finetune.py
+1
-14
tasks/vision/classification/classification.py
tasks/vision/classification/classification.py
+2
-15
tasks/vision/classification/eval_utils.py
tasks/vision/classification/eval_utils.py
+1
-14
tasks/vision/finetune_utils.py
tasks/vision/finetune_utils.py
+6
-19
tasks/vision/main.py
tasks/vision/main.py
+1
-14
tasks/vision/segmentation/finetune_segformer.py
tasks/vision/segmentation/finetune_segformer.py
+2
-15
tasks/vision/segmentation/finetune_setr.py
tasks/vision/segmentation/finetune_setr.py
+2
-15
tasks/vision/segmentation/seg_heads.py
tasks/vision/segmentation/seg_heads.py
+1
-14
tasks/vision/segmentation/seg_models.py
tasks/vision/segmentation/seg_models.py
+1
-14
No files found.
tasks/main.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
"""Main tasks functionality."""
...
...
tasks/msdp/evaluate.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model evaluation"""
"""Model evaluation"""
...
...
tasks/msdp/main.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run multi-stage dialogue prompting (MSDP)."""
"""Run multi-stage dialogue prompting (MSDP)."""
...
...
tasks/msdp/preprocessing.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
...
...
tasks/msdp/prompt.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompting the pretrained language model to generate knowledge/response"""
"""Prompting the pretrained language model to generate knowledge/response"""
...
...
tasks/orqa/evaluate_orqa.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
"""Main tasks functionality."""
...
...
tasks/orqa/evaluate_utils.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch
...
...
tasks/orqa/supervised/data.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA dataset."""
"""ORQA dataset."""
...
...
tasks/orqa/supervised/eval_utils.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
"""Evaluation utilities."""
from
collections
import
OrderedDict
from
collections
import
OrderedDict
...
...
tasks/orqa/supervised/finetune.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA finetuning/evaluation."""
"""ORQA finetuning/evaluation."""
...
@@ -63,7 +50,7 @@ def orqa(Dataset):
...
@@ -63,7 +50,7 @@ def orqa(Dataset):
tokenizer
=
get_tokenizer
()
tokenizer
=
get_tokenizer
()
# Get the batch.
# Get the batch.
timers
(
'batch generator'
).
start
()
timers
(
'batch generator'
,
log_level
=
2
).
start
()
try
:
try
:
batch_
=
next
(
batch
)
batch_
=
next
(
batch
)
except
BaseException
:
except
BaseException
:
...
...
tasks/orqa/unsupervised/nq.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
"""
Data Loader for Google NQ dataset
Data Loader for Google NQ dataset
...
...
tasks/race/finetune.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Race."""
"""Race."""
...
...
tasks/vision/classification/classification.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation."""
"""Vision-classification finetuning/evaluation."""
...
@@ -68,7 +55,7 @@ def classification():
...
@@ -68,7 +55,7 @@ def classification():
timers
=
get_timers
()
timers
=
get_timers
()
# Get the batch.
# Get the batch.
timers
(
"batch generator"
).
start
()
timers
(
"batch generator"
,
log_level
=
2
).
start
()
try
:
try
:
batch_
=
next
(
batch
)
batch_
=
next
(
batch
)
except
BaseException
:
except
BaseException
:
...
...
tasks/vision/classification/eval_utils.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
"""Evaluation utilities."""
...
...
tasks/vision/finetune_utils.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetune utilities."""
"""Finetune utilities."""
...
@@ -136,7 +123,7 @@ def _train(
...
@@ -136,7 +123,7 @@ def _train(
report_memory_flag
=
True
report_memory_flag
=
True
# For each remaining epoch
# For each remaining epoch
timers
(
"interval-time"
).
start
(
)
timers
(
"interval-time"
,
log_level
=
0
).
start
(
barrier
=
True
)
for
epoch
in
range
(
start_epoch
,
args
.
epochs
):
for
epoch
in
range
(
start_epoch
,
args
.
epochs
):
print_rank_0
(
"working on epoch {} ..."
.
format
(
epoch
+
1
))
print_rank_0
(
"working on epoch {} ..."
.
format
(
epoch
+
1
))
...
@@ -218,7 +205,7 @@ def finetune(
...
@@ -218,7 +205,7 @@ def finetune(
timers
=
get_timers
()
timers
=
get_timers
()
# Train and validation data loaders.
# Train and validation data loaders.
timers
(
"train/valid/test dataset/dataloder"
).
start
()
timers
(
"train/valid/test dataset/dataloder"
,
log_level
=
0
).
start
()
if
args
.
epochs
>
0
:
if
args
.
epochs
>
0
:
train_dataset
,
valid_dataset
=
train_valid_datasets_provider
()
train_dataset
,
valid_dataset
=
train_valid_datasets_provider
()
train_dataloader
,
valid_dataloader
=
_build_train_valid_dataloaders
(
train_dataloader
,
valid_dataloader
=
_build_train_valid_dataloaders
(
...
@@ -227,14 +214,14 @@ def finetune(
...
@@ -227,14 +214,14 @@ def finetune(
timers
(
"train/valid/test dataset/dataloder"
).
stop
()
timers
(
"train/valid/test dataset/dataloder"
).
stop
()
# Build calback function.
# Build calback function.
timers
(
"callback function"
).
start
()
timers
(
"callback function"
,
log_level
=
0
).
start
()
end_of_epoch_callback
=
None
end_of_epoch_callback
=
None
if
end_of_epoch_callback_provider
is
not
None
:
if
end_of_epoch_callback_provider
is
not
None
:
end_of_epoch_callback
=
end_of_epoch_callback_provider
()
end_of_epoch_callback
=
end_of_epoch_callback_provider
()
timers
(
"callback function"
).
stop
()
timers
(
"callback function"
).
stop
()
# Build model, optimizer and learning rate scheduler.
# Build model, optimizer and learning rate scheduler.
timers
(
"model and optimizer"
).
start
()
timers
(
"model and optimizer"
,
log_level
=
0
).
start
()
model
,
optimizer
,
opt_param_scheduler
=
\
model
,
optimizer
,
opt_param_scheduler
=
\
setup_model_and_optimizer
(
setup_model_and_optimizer
(
model_provider
,
model_provider
,
...
@@ -246,7 +233,7 @@ def finetune(
...
@@ -246,7 +233,7 @@ def finetune(
# If pretrained checkpoint is provided and we have not trained for
# If pretrained checkpoint is provided and we have not trained for
# any iteration (i.e., iteration is zero), then load the pretrained
# any iteration (i.e., iteration is zero), then load the pretrained
# checkpoint.
# checkpoint.
timers
(
"pretrained checkpoint"
).
start
(
)
timers
(
"pretrained checkpoint"
,
log_level
=
0
).
start
(
barrier
=
True
)
if
args
.
iteration
==
0
and
args
.
pretrained_checkpoint
is
not
None
:
if
args
.
iteration
==
0
and
args
.
pretrained_checkpoint
is
not
None
:
if
args
.
pretrained_checkpoint_type
==
'default'
:
if
args
.
pretrained_checkpoint_type
==
'default'
:
original_load
=
args
.
load
original_load
=
args
.
load
...
...
tasks/vision/main.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main tasks functionality."""
"""Main tasks functionality."""
...
...
tasks/vision/segmentation/finetune_segformer.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation."""
"""Vision-classification finetuning/evaluation."""
...
@@ -123,7 +110,7 @@ def segmentation():
...
@@ -123,7 +110,7 @@ def segmentation():
timers
=
get_timers
()
timers
=
get_timers
()
# Get the batch.
# Get the batch.
timers
(
"batch generator"
).
start
()
timers
(
"batch generator"
,
log_level
=
2
).
start
()
import
types
import
types
if
isinstance
(
batch
,
types
.
GeneratorType
):
if
isinstance
(
batch
,
types
.
GeneratorType
):
batch_
=
next
(
batch
)
batch_
=
next
(
batch
)
...
...
tasks/vision/segmentation/finetune_setr.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation."""
"""Vision-classification finetuning/evaluation."""
...
@@ -86,7 +73,7 @@ def segmentation():
...
@@ -86,7 +73,7 @@ def segmentation():
timers
=
get_timers
()
timers
=
get_timers
()
# Get the batch.
# Get the batch.
timers
(
"batch generator"
).
start
()
timers
(
"batch generator"
,
log_level
=
2
).
start
()
import
types
import
types
if
isinstance
(
batch
,
types
.
GeneratorType
):
if
isinstance
(
batch
,
types
.
GeneratorType
):
batch_
=
next
(
batch
)
batch_
=
next
(
batch
)
...
...
tasks/vision/segmentation/seg_heads.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
math
import
einops
import
einops
import
torch
import
torch
...
...
tasks/vision/segmentation/seg_models.py
View file @
41276b6c
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
math
import
einops
import
einops
import
torch
import
torch
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment