Commit 1d391bba authored by rprenger's avatar rprenger
Browse files

Addressing comments

parent b0c824d9
<!-- coding=utf-8-->
<!-- Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.-->
<!---->
<!-- Licensed under the Apache License, Version 2.0 (the "License");-->
<!-- you may not use this file except in compliance with the License.-->
<!-- You may obtain a copy of the License at-->
<!---->
<!-- http://www.apache.org/licenses/LICENSE-2.0-->
<!---->
<!-- Unless required by applicable law or agreed to in writing, software-->
<!-- distributed under the License is distributed on an "AS IS" BASIS,-->
<!-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.-->
<!-- See the License for the specific language governing permissions and-->
<!-- limitations under the License.-->
<!DOCTYPE html> <!DOCTYPE html>
<html lang="en"> <html lang="en">
<head> <head>
......
...@@ -234,6 +234,8 @@ def generate_tokens_probs_and_return_on_first_stage( ...@@ -234,6 +234,8 @@ def generate_tokens_probs_and_return_on_first_stage(
# Check if all the sequences have hit the termination_id. # Check if all the sequences have hit the termination_id.
done = None done = None
if mpu.is_pipeline_last_stage(): if mpu.is_pipeline_last_stage():
# TODO(rprenger) These stopping methods are tokenizer dependent
# instead tokenization should be in the inference loop so stop sequences can be used
if stop_on_double_eol: if stop_on_double_eol:
hit_double_eol = (new_sample == 628).byte() & started.byte() hit_double_eol = (new_sample == 628).byte() & started.byte()
hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte() hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment