"...git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "f43ec2ddfa01bb6d69328c00f9f3893cfc74f364"
Commit 2d064933 authored by carlushuang's avatar carlushuang
Browse files

support change batch size in xdnn desc

parent c32f0c3b
...@@ -240,13 +240,13 @@ int main(int argc, char* argv[]) ...@@ -240,13 +240,13 @@ int main(int argc, char* argv[])
if(ck::getenv_int("CK_USE_XDNN_DESC", 0) == 1) if(ck::getenv_int("CK_USE_XDNN_DESC", 0) == 1)
{ {
assert(argc == 4); assert(argc == 4 || argc == 5);
data_type = std::stoi(argv[1]); data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
ck::desc_t xdnn_desc; ck::desc_t xdnn_desc;
if(str2desc(&xdnn_desc, argv[3]) == XDNN_OK) if(str2desc(&xdnn_desc, argv[argc - 1]) == XDNN_OK)
{ {
N = xdnn_desc.mb; N = xdnn_desc.mb;
K = xdnn_desc.oc; K = xdnn_desc.oc;
...@@ -266,8 +266,11 @@ int main(int argc, char* argv[]) ...@@ -266,8 +266,11 @@ int main(int argc, char* argv[])
} }
else else
{ {
printf("fail to parse xdnn arg:%s\n", argv[3]); printf("fail to parse xdnn arg:%s\n", argv[argc - 1]);
exit(1);
} }
if(argc == 5)
N = std::stoi(argv[3]);
} }
else else
{ {
......
...@@ -276,13 +276,13 @@ int main(int argc, char* argv[]) ...@@ -276,13 +276,13 @@ int main(int argc, char* argv[])
if(ck::getenv_int("CK_USE_XDNN_DESC", 0) == 1) if(ck::getenv_int("CK_USE_XDNN_DESC", 0) == 1)
{ {
assert(argc == 4); assert(argc == 4 || argc == 5);
data_type = std::stoi(argv[1]); data_type = std::stoi(argv[1]);
init_method = std::stoi(argv[2]); init_method = std::stoi(argv[2]);
ck::desc_t xdnn_desc; ck::desc_t xdnn_desc;
if(str2desc(&xdnn_desc, argv[3]) == XDNN_OK) if(str2desc(&xdnn_desc, argv[argc - 1]) == XDNN_OK)
{ {
N = xdnn_desc.mb; N = xdnn_desc.mb;
K = xdnn_desc.oc; K = xdnn_desc.oc;
...@@ -302,8 +302,11 @@ int main(int argc, char* argv[]) ...@@ -302,8 +302,11 @@ int main(int argc, char* argv[])
} }
else else
{ {
printf("fail to parse xdnn arg:%s\n", argv[3]); printf("fail to parse xdnn arg:%s\n", argv[argc - 1]);
exit(1);
} }
if(argc == 5)
N = std::stoi(argv[3]);
} }
else else
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment