OpenDAS / Megatron-LM · Commits · 7a9c4a03
"tests/vscode:/vscode.git/clone" did not exist on "84c7dc3436ad775fb09ab5ab8ac8028f5caee9d2"
Commit 7a9c4a03, authored Jul 19, 2021 by rprenger
Removing bug possibilities and adding timing info
Parent: 29dd0a35
Showing 3 changed files with 9 additions and 2 deletions.
megatron/api_server.py             +1 -1
megatron/text_generation_utils.py  +5 -0
tools/run_cli.py                   +3 -1
megatron/api_server.py
@@ -61,4 +61,4 @@ class MegatronServer(object):
         api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])

     def run(self, url):
-        self.app.run(url, debug=False)
+        self.app.run(url, threaded=False, debug=False)
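The api_server.py change passes threaded=False to Flask's development server so that requests to /generate are handled one at a time, presumably to avoid concurrent requests driving a non-thread-safe model (one of the "bug possibilities" the commit message mentions). Below is a minimal sketch of that pattern, assuming Flask and flask_restful as used by megatron/api_server.py; the resource class and model object are hypothetical stand-ins, not Megatron code.

from flask import Flask, request
from flask_restful import Api, Resource

class EchoGenerate(Resource):
    # Hypothetical stand-in for MegatronGenerate: echoes the request back.
    def __init__(self, model):
        self.model = model

    def put(self):
        payload = request.get_json()
        # A real server would run model inference here.
        return {"sentences": payload.get("sentences", [])}

def build_server(model):
    app = Flask(__name__)
    api = Api(app)
    # Same registration pattern as MegatronServer in api_server.py.
    api.add_resource(EchoGenerate, '/generate', resource_class_args=[model])
    return app

if __name__ == "__main__":
    app = build_server(model=None)
    # threaded=False makes the dev server handle requests serially, so two
    # concurrent PUTs cannot reach the model at the same time.
    app.run("0.0.0.0", threaded=False, debug=False)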
megatron/text_generation_utils.py
@@ -162,6 +162,9 @@ def synced_generate(model, context_length_tensor, context_tokens_tensor, max_len
 def generate(model, sentences=None, max_len=0):
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences)
+        c = context_length_tensor[0]
+        b = context_tokens_tensor.size(0)
+        start = time.time()
         send_generate_info(context_tokens_tensor, context_length_tensor, max_len)
     else:
         context_length_tensor, context_tokens_tensor, max_len = receive_generate_info()
@@ -176,6 +179,8 @@ def generate(model, sentences=None, max_len=0):
     for i in range(decode_tokens.size(0)):
         decode_token = decode_tokens[i,:].cpu().numpy().tolist()
         resp_sentences.append(tokenizer.detokenize(decode_token))
+    end = time.time()
+    print(str(b) + "," + str(c) + "," + str(decode_tokens.size(1)) + "," + str(end-start), flush=True)
     return resp_sentences

 def switch(val1, val2, boolean):
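The text_generation_utils.py change records the batch size, the context length of the first sample, the final decoded length, and the wall-clock time of each generate() call, printing them as one comma-separated line. A rough, self-contained sketch of that instrumentation pattern, with hypothetical stand-ins (fake_generate, the vocabulary size) in place of the real Megatron generation path:

import time
import torch

def fake_generate(tokens, max_len):
    # Pretend "generation": append max_len random token ids per sequence.
    new = torch.randint(0, 50257, (tokens.size(0), max_len))
    return torch.cat([tokens, new], dim=1)

def timed_generate(context_tokens_tensor, context_length_tensor, max_len):
    b = context_tokens_tensor.size(0)    # batch size
    c = context_length_tensor[0]         # context length of first sample
    start = time.time()
    decode_tokens = fake_generate(context_tokens_tensor, max_len)
    end = time.time()
    # Same CSV-style line as the commit: batch, context len, total len, seconds.
    print(str(b) + "," + str(int(c)) + "," + str(decode_tokens.size(1)) + ","
          + str(end - start), flush=True)
    return decode_tokens

if __name__ == "__main__":
    ctx = torch.randint(0, 50257, (4, 16))
    lens = torch.full((4,), 16)
    timed_generate(ctx, lens, max_len=8)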
tools/run_cli.py
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+import sys
 import urllib2
 class PutRequest(urllib2.Request):
     '''class to handling putting with urllib2'''
@@ -21,11 +22,12 @@ class PutRequest(urllib2.Request):
         return 'PUT'

 if __name__ == "__main__":
+    url = sys.argv[1]
     while True:
         sentence = raw_input("Enter prompt: ")
         max_len = int(input("Enter number tokens output: "))
         data = json.dumps({"sentences": [sentence], "max_len": max_len})
-        req = PutRequest("http://sc-sdgx2-484:5000/generate", data, {'Content-Type': 'application/json'})
+        req = PutRequest(url, data, {'Content-Type': 'application/json'})
         response = urllib2.urlopen(req)
         resp_sentences = json.load(response)
         print("Megatron Response: ")
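With this change the CLI takes the server address from the command line (for example: python tools/run_cli.py http://<server-host>:5000/generate, with the host as a placeholder) instead of using a hard-coded URL. Note that the script is written for Python 2 (urllib2, raw_input). A rough Python 3 equivalent of the same PUT client, offered only as a sketch and not part of the repository:

import json
import sys
import urllib.request

if __name__ == "__main__":
    url = sys.argv[1]   # e.g. http://localhost:5000/generate
    while True:
        sentence = input("Enter prompt: ")
        max_len = int(input("Enter number tokens output: "))
        data = json.dumps({"sentences": [sentence], "max_len": max_len}).encode("utf-8")
        # urllib.request.Request with method='PUT' replaces the PutRequest subclass.
        req = urllib.request.Request(url, data=data,
                                     headers={'Content-Type': 'application/json'},
                                     method='PUT')
        with urllib.request.urlopen(req) as response:
            resp_sentences = json.load(response)
        print("Megatron Response: ")
        print(resp_sentences)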