Commit 079bf3cb (unverified), authored by Hongxin Liu, committed by GitHub
Browse files

[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
parent 3c6b831c
[flake8]
ignore =
;W503 line break before binary operator
W503,
;E203 whitespace before ':'
E203,
; exclude file
exclude =
.tox,
.git,
__pycache__,
build,
dist,
*.pyc,
*.egg-info,
.cache,
.eggs
max-line-length = 120
per-file-ignores = __init__.py:F401
...@@ -22,13 +22,13 @@ def compare_dirs(dir1, dir2): ...@@ -22,13 +22,13 @@ def compare_dirs(dir1, dir2):
# If the corresponding item doesn't exist in the second directory, the directories are different # If the corresponding item doesn't exist in the second directory, the directories are different
if not os.path.exists(item_path2): if not os.path.exists(item_path2):
print(f'Found mismatch: {item_path1}, {item_path2}') print(f"Found mismatch: {item_path1}, {item_path2}")
return False return False
# If the corresponding item is a directory, we compare the two directories recursively # If the corresponding item is a directory, we compare the two directories recursively
if os.path.isdir(item_path1) and os.path.isdir(item_path2): if os.path.isdir(item_path1) and os.path.isdir(item_path2):
if not compare_dirs(item_path1, item_path2): if not compare_dirs(item_path1, item_path2):
print(f'Found mismatch: {item_path1}, {item_path2}') print(f"Found mismatch: {item_path1}, {item_path2}")
return False return False
# both are files # both are files
...@@ -37,16 +37,16 @@ def compare_dirs(dir1, dir2): ...@@ -37,16 +37,16 @@ def compare_dirs(dir1, dir2):
# If the corresponding item is not a file or a directory, the directories are different # If the corresponding item is not a file or a directory, the directories are different
else: else:
print(f'Found mismatch: {item_path1}, {item_path2}') print(f"Found mismatch: {item_path1}, {item_path2}")
return False return False
# If all items are the same, the directories are the same # If all items are the same, the directories are the same
return True return True
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-d', '--directory', help="The directory where the multi-language source files are kept.") parser.add_argument("-d", "--directory", help="The directory where the multi-language source files are kept.")
args = parser.parse_args() args = parser.parse_args()
i18n_folders = os.listdir(args.directory) i18n_folders = os.listdir(args.directory)
...@@ -56,7 +56,7 @@ if __name__ == '__main__': ...@@ -56,7 +56,7 @@ if __name__ == '__main__':
for i in range(1, len(i18n_folders)): for i in range(1, len(i18n_folders)):
dir1 = i18n_folders[0] dir1 = i18n_folders[0]
dir2 = i18n_folders[i] dir2 = i18n_folders[i]
print(f'comparing {dir1} vs {dir2}') print(f"comparing {dir1} vs {dir2}")
match = compare_dirs(i18n_folders[0], i18n_folders[i]) match = compare_dirs(i18n_folders[0], i18n_folders[i])
if not match: if not match:
......
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
def check_inputs(input_list): def check_inputs(input_list):
for path in input_list: for path in input_list:
real_path = os.path.join('examples', path) real_path = os.path.join("examples", path)
if not os.path.exists(real_path): if not os.path.exists(real_path):
return False return False
return True return True
...@@ -12,16 +12,16 @@ def check_inputs(input_list): ...@@ -12,16 +12,16 @@ def check_inputs(input_list):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-f', '--fileNameList', type=str, help="List of file names") parser.add_argument("-f", "--fileNameList", type=str, help="List of file names")
args = parser.parse_args() args = parser.parse_args()
name_list = args.fileNameList.split(",") name_list = args.fileNameList.split(",")
is_correct = check_inputs(name_list) is_correct = check_inputs(name_list)
if is_correct: if is_correct:
print('success') print("success")
else: else:
print('failure') print("failure")
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -17,21 +17,21 @@ def show_files(path, all_files): ...@@ -17,21 +17,21 @@ def show_files(path, all_files):
def join(input_list, sep=None): def join(input_list, sep=None):
return (sep or ' ').join(input_list) return (sep or " ").join(input_list)
def main(): def main():
contents = show_files('examples/', []) contents = show_files("examples/", [])
all_loc = [] all_loc = []
for file_loc in contents: for file_loc in contents:
split_loc = file_loc.split('/') split_loc = file_loc.split("/")
# must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not. # must have two sub-folder levels after examples folder, such as examples/images/vit is acceptable, examples/images/README.md is not, examples/requirements.txt is not.
if len(split_loc) >= 4: if len(split_loc) >= 4:
re_loc = '/'.join(split_loc[1:3]) re_loc = "/".join(split_loc[1:3])
if re_loc not in all_loc: if re_loc not in all_loc:
all_loc.append(re_loc) all_loc.append(re_loc)
print(all_loc) print(all_loc)
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -3,7 +3,7 @@ import argparse ...@@ -3,7 +3,7 @@ import argparse
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-f', '--fileNameList', type=str, help="The list of changed files") parser.add_argument("-f", "--fileNameList", type=str, help="The list of changed files")
args = parser.parse_args() args = parser.parse_args()
name_list = args.fileNameList.split(":") name_list = args.fileNameList.split(":")
folder_need_check = set() folder_need_check = set()
...@@ -15,10 +15,10 @@ def main(): ...@@ -15,10 +15,10 @@ def main():
# - application # - application
# - file # - file
if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4: if loc.split("/")[0] == "examples" and len(loc.split("/")) >= 4:
folder_need_check.add('/'.join(loc.split("/")[1:3])) folder_need_check.add("/".join(loc.split("/")[1:3]))
# Output the result using print. Then the shell can get the values. # Output the result using print. Then the shell can get the values.
print(list(folder_need_check)) print(list(folder_need_check))
if __name__ == '__main__': if __name__ == "__main__":
main() main()
...@@ -74,16 +74,16 @@ def get_organization_repositories(github_token, organization_name) -> List[str]: ...@@ -74,16 +74,16 @@ def get_organization_repositories(github_token, organization_name) -> List[str]:
# prepare header # prepare header
headers = { headers = {
'Authorization': f'Bearer {github_token}', "Authorization": f"Bearer {github_token}",
'Accept': 'application/vnd.github+json', "Accept": "application/vnd.github+json",
'X-GitHub-Api-Version': '2022-11-28' "X-GitHub-Api-Version": "2022-11-28",
} }
res = requests.get(url, headers=headers).json() res = requests.get(url, headers=headers).json()
repo_list = [] repo_list = []
for item in res: for item in res:
repo_list.append(item['name']) repo_list.append(item["name"])
return repo_list return repo_list
...@@ -97,9 +97,9 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name: ...@@ -97,9 +97,9 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
""" """
# prepare header # prepare header
headers = { headers = {
'Authorization': f'Bearer {github_token}', "Authorization": f"Bearer {github_token}",
'Accept': 'application/vnd.github+json', "Accept": "application/vnd.github+json",
'X-GitHub-Api-Version': '2022-11-28' "X-GitHub-Api-Version": "2022-11-28",
} }
user_engagement_count = {} user_engagement_count = {}
...@@ -107,28 +107,28 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name: ...@@ -107,28 +107,28 @@ def get_issue_pull_request_comments(github_token: str, org_name: str, repo_name:
# do pagination to the API # do pagination to the API
page = 1 page = 1
while True: while True:
comment_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}' comment_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/comments?since={since}&page={page}"
comment_response = requests.get(comment_api, headers=headers).json() comment_response = requests.get(comment_api, headers=headers).json()
if len(comment_response) == 0: if len(comment_response) == 0:
break break
else: else:
for item in comment_response: for item in comment_response:
comment_author_relationship = item['author_association'] comment_author_relationship = item["author_association"]
if comment_author_relationship != 'MEMBER': if comment_author_relationship != "MEMBER":
# if the comment is not made by our member # if the comment is not made by our member
# we don't count this comment towards user engagement # we don't count this comment towards user engagement
continue continue
issue_id = item['issue_url'].split('/')[-1] issue_id = item["issue_url"].split("/")[-1]
issue_api = f'https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}' issue_api = f"https://api.github.com/repos/{org_name}/{repo_name}/issues/{issue_id}"
issue_response = requests.get(issue_api, headers=headers).json() issue_response = requests.get(issue_api, headers=headers).json()
issue_author_relationship = issue_response['author_association'] issue_author_relationship = issue_response["author_association"]
if issue_author_relationship != 'MEMBER': if issue_author_relationship != "MEMBER":
# this means that the issue/PR is not created by our own people # this means that the issue/PR is not created by our own people
# any comments in this issue/PR by our member will be counted towards the leaderboard # any comments in this issue/PR by our member will be counted towards the leaderboard
member_name = item['user']['login'] member_name = item["user"]["login"]
if member_name in user_engagement_count: if member_name in user_engagement_count:
user_engagement_count[member_name] += 1 user_engagement_count[member_name] += 1
...@@ -153,7 +153,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si ...@@ -153,7 +153,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
if cursor is None: if cursor is None:
offset_str = "" offset_str = ""
else: else:
offset_str = f", after: \"{cursor}\"" offset_str = f', after: "{cursor}"'
query = f""" query = f"""
{{ {{
repository(owner: "{org_name}", name: "{repo_name}"){{ repository(owner: "{org_name}", name: "{repo_name}"){{
...@@ -182,7 +182,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si ...@@ -182,7 +182,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
if cursor is None: if cursor is None:
offset_str = "" offset_str = ""
else: else:
offset_str = f", before: \"{cursor}\"" offset_str = f', before: "{cursor}"'
query = f""" query = f"""
{{ {{
repository(owner: "{org_name}", name: "{repo_name}"){{ repository(owner: "{org_name}", name: "{repo_name}"){{
...@@ -220,8 +220,8 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si ...@@ -220,8 +220,8 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
# a utility function to make call to Github GraphQL API # a utility function to make call to Github GraphQL API
def _call_graphql_api(query): def _call_graphql_api(query):
headers = {"Authorization": f"Bearer {github_token}"} headers = {"Authorization": f"Bearer {github_token}"}
json_data = {'query': query} json_data = {"query": query}
response = requests.post('https://api.github.com/graphql', json=json_data, headers=headers) response = requests.post("https://api.github.com/graphql", json=json_data, headers=headers)
data = response.json() data = response.json()
return data return data
...@@ -234,21 +234,21 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si ...@@ -234,21 +234,21 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
data = _call_graphql_api(query) data = _call_graphql_api(query)
found_discussion_out_of_time_range = False found_discussion_out_of_time_range = False
edges = data['data']['repository']['discussions']['edges'] edges = data["data"]["repository"]["discussions"]["edges"]
if len(edges) == 0: if len(edges) == 0:
break break
else: else:
# keep the discussion whose author is not a member # keep the discussion whose author is not a member
for edge in edges: for edge in edges:
# print the discussion title # print the discussion title
discussion = edge['node'] discussion = edge["node"]
discussion_updated_at = str2datetime(discussion['updatedAt']) discussion_updated_at = str2datetime(discussion["updatedAt"])
# check if the updatedAt is within the last 7 days # check if the updatedAt is within the last 7 days
# if yes, add it to discussion_numbers # if yes, add it to discussion_numbers
if discussion_updated_at > since: if discussion_updated_at > since:
if discussion['authorAssociation'] != 'MEMBER': if discussion["authorAssociation"] != "MEMBER":
discussion_numbers.append(discussion['number']) discussion_numbers.append(discussion["number"])
else: else:
found_discussion_out_of_time_range = True found_discussion_out_of_time_range = True
...@@ -256,7 +256,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si ...@@ -256,7 +256,7 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
break break
else: else:
# update cursor # update cursor
cursor = edges[-1]['cursor'] cursor = edges[-1]["cursor"]
# get the discussion comments and replies made by our member # get the discussion comments and replies made by our member
user_engagement_count = {} user_engagement_count = {}
...@@ -269,42 +269,42 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si ...@@ -269,42 +269,42 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
data = _call_graphql_api(query) data = _call_graphql_api(query)
# get the comments # get the comments
edges = data['data']['repository']['discussion']['comments']['edges'] edges = data["data"]["repository"]["discussion"]["comments"]["edges"]
# update the cursor # update the cursor
if len(edges) == 0: if len(edges) == 0:
break break
else: else:
# update cursor for pagination # update cursor for pagination
cursor = edges[-1]['cursor'] cursor = edges[-1]["cursor"]
for edge in edges: for edge in edges:
comment = edge['node'] comment = edge["node"]
if comment['authorAssociation'] == 'MEMBER': if comment["authorAssociation"] == "MEMBER":
# check if the updatedAt is within the last 7 days # check if the updatedAt is within the last 7 days
# if yes, add it to user_engagement_count # if yes, add it to user_engagement_count
comment_updated_at = datetime.strptime(comment['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") comment_updated_at = datetime.strptime(comment["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
if comment_updated_at > since: if comment_updated_at > since:
member_name = comment['author']['login'] member_name = comment["author"]["login"]
if member_name in user_engagement_count: if member_name in user_engagement_count:
user_engagement_count[member_name] += 1 user_engagement_count[member_name] += 1
else: else:
user_engagement_count[member_name] = 1 user_engagement_count[member_name] = 1
# get the replies # get the replies
reply_edges = comment['replies']['edges'] reply_edges = comment["replies"]["edges"]
if len(reply_edges) == 0: if len(reply_edges) == 0:
continue continue
else: else:
for reply_edge in reply_edges: for reply_edge in reply_edges:
reply = reply_edge['node'] reply = reply_edge["node"]
if reply['authorAssociation'] == 'MEMBER': if reply["authorAssociation"] == "MEMBER":
# check if the updatedAt is within the last 7 days # check if the updatedAt is within the last 7 days
# if yes, add it to discussion_numbers # if yes, add it to discussion_numbers
reply_updated_at = datetime.strptime(reply['updatedAt'], "%Y-%m-%dT%H:%M:%SZ") reply_updated_at = datetime.strptime(reply["updatedAt"], "%Y-%m-%dT%H:%M:%SZ")
if reply_updated_at > since: if reply_updated_at > since:
member_name = reply['author']['login'] member_name = reply["author"]["login"]
if member_name in user_engagement_count: if member_name in user_engagement_count:
user_engagement_count[member_name] += 1 user_engagement_count[member_name] += 1
else: else:
...@@ -312,7 +312,9 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si ...@@ -312,7 +312,9 @@ def get_discussion_comments(github_token: str, org_name: str, repo_name: str, si
return user_engagement_count return user_engagement_count
def generate_user_engagement_leaderboard_image(github_token: str, org_name: str, repo_list: List[str], output_path: str) -> bool: def generate_user_engagement_leaderboard_image(
github_token: str, org_name: str, repo_list: List[str], output_path: str
) -> bool:
""" """
Generate the user engagement leaderboard image for stats within the last 7 days Generate the user engagement leaderboard image for stats within the last 7 days
...@@ -335,11 +337,14 @@ def generate_user_engagement_leaderboard_image(github_token: str, org_name: str, ...@@ -335,11 +337,14 @@ def generate_user_engagement_leaderboard_image(github_token: str, org_name: str,
else: else:
total_engagement_count[name] = count total_engagement_count[name] = count
for repo_name in repo_list: for repo_name in repo_list:
print(f"Fetching user engagement count for {repo_name}/{repo_name}") print(f"Fetching user engagement count for {repo_name}/{repo_name}")
issue_pr_engagement_count = get_issue_pull_request_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str) issue_pr_engagement_count = get_issue_pull_request_comments(
discussion_engagement_count = get_discussion_comments(github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime) github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime_str
)
discussion_engagement_count = get_discussion_comments(
github_token=github_token, org_name=org_name, repo_name=repo_name, since=start_datetime
)
# update the total engagement count # update the total engagement count
_update_count(issue_pr_engagement_count) _update_count(issue_pr_engagement_count)
...@@ -363,7 +368,7 @@ def generate_user_engagement_leaderboard_image(github_token: str, org_name: str, ...@@ -363,7 +368,7 @@ def generate_user_engagement_leaderboard_image(github_token: str, org_name: str,
# plot the leaderboard # plot the leaderboard
xlabel = f"Number of Comments made (since {start_datetime_str})" xlabel = f"Number of Comments made (since {start_datetime_str})"
ylabel = "Member" ylabel = "Member"
title = 'Active User Engagement Leaderboard' title = "Active User Engagement Leaderboard"
plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path) plot_bar_chart(x, y, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
return True return True
else: else:
...@@ -380,16 +385,16 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou ...@@ -380,16 +385,16 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
""" """
# request to the Github API to get the users who have contributed in the last 7 days # request to the Github API to get the users who have contributed in the last 7 days
headers = { headers = {
'Authorization': f'Bearer {github_token}', "Authorization": f"Bearer {github_token}",
'Accept': 'application/vnd.github+json', "Accept": "application/vnd.github+json",
'X-GitHub-Api-Version': '2022-11-28' "X-GitHub-Api-Version": "2022-11-28",
} }
counter = Counter() counter = Counter()
start_datetime = get_utc_time_one_week_ago() start_datetime = get_utc_time_one_week_ago()
def _get_url(org_name, repo_name, page): def _get_url(org_name, repo_name, page):
return f'https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed' return f"https://api.github.com/repos/{org_name}/{repo_name}/pulls?per_page=50&page={page}&state=closed"
def _iterate_by_page(org_name, repo_name): def _iterate_by_page(org_name, repo_name):
page = 1 page = 1
...@@ -415,8 +420,8 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou ...@@ -415,8 +420,8 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
# count the pull request and author from response # count the pull request and author from response
for pr_data in response: for pr_data in response:
merged_at = pr_data['merged_at'] merged_at = pr_data["merged_at"]
author = pr_data['user']['login'] author = pr_data["user"]["login"]
if merged_at is None: if merged_at is None:
continue continue
...@@ -439,7 +444,7 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou ...@@ -439,7 +444,7 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
_iterate_by_page(org_name, repo_name) _iterate_by_page(org_name, repo_name)
# convert unix timestamp to Beijing datetime # convert unix timestamp to Beijing datetime
bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone('Asia/Shanghai')) bj_start_datetime = datetime.fromtimestamp(start_datetime.timestamp(), tz=pytz.timezone("Asia/Shanghai"))
bj_start_datetime_str = datetime2str(bj_start_datetime) bj_start_datetime_str = datetime2str(bj_start_datetime)
contribution_list = counter.to_sorted_list() contribution_list = counter.to_sorted_list()
...@@ -452,7 +457,7 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou ...@@ -452,7 +457,7 @@ def generate_contributor_leaderboard_image(github_token, org_name, repo_list, ou
if len(author_list) > 0: if len(author_list) > 0:
xlabel = f"Number of Pull Requests (since {bj_start_datetime_str})" xlabel = f"Number of Pull Requests (since {bj_start_datetime_str})"
ylabel = "Contributor" ylabel = "Contributor"
title = 'Active Contributor Leaderboard' title = "Active Contributor Leaderboard"
plot_bar_chart(num_commit_list, author_list, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path) plot_bar_chart(num_commit_list, author_list, xlabel=xlabel, ylabel=ylabel, title=title, output_path=output_path)
return True return True
else: else:
...@@ -468,14 +473,14 @@ def upload_image_to_lark(lark_tenant_token: str, image_path: str) -> str: ...@@ -468,14 +473,14 @@ def upload_image_to_lark(lark_tenant_token: str, image_path: str) -> str:
image_path (str): the path to the image to be uploaded image_path (str): the path to the image to be uploaded
""" """
url = "https://open.feishu.cn/open-apis/im/v1/images" url = "https://open.feishu.cn/open-apis/im/v1/images"
form = {'image_type': 'message', 'image': (open(image_path, 'rb'))} # 需要替换具体的path form = {"image_type": "message", "image": (open(image_path, "rb"))} # 需要替换具体的path
multi_form = MultipartEncoder(form) multi_form = MultipartEncoder(form)
headers = { headers = {
'Authorization': f'Bearer {lark_tenant_token}', ## 获取tenant_access_token, 需要替换为实际的token "Authorization": f"Bearer {lark_tenant_token}", ## 获取tenant_access_token, 需要替换为实际的token
} }
headers['Content-Type'] = multi_form.content_type headers["Content-Type"] = multi_form.content_type
response = requests.request("POST", url, headers=headers, data=multi_form).json() response = requests.request("POST", url, headers=headers, data=multi_form).json()
return response['data']['image_key'] return response["data"]["image_key"]
def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str: def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
...@@ -486,10 +491,10 @@ def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str: ...@@ -486,10 +491,10 @@ def generate_lark_tenant_access_token(app_id: str, app_secret: str) -> str:
app_id (str): Lark app id app_id (str): Lark app id
app_secret (str): Lark app secret app_secret (str): Lark app secret
""" """
url = 'https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal' url = "https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal"
data = {'app_id': app_id, 'app_secret': app_secret} data = {"app_id": app_id, "app_secret": app_secret}
response = requests.post(url, json=data).json() response = requests.post(url, json=data).json()
return response['tenant_access_token'] return response["tenant_access_token"]
def send_image_to_lark(image_key: str, webhook_url: str) -> None: def send_image_to_lark(image_key: str, webhook_url: str) -> None:
...@@ -516,10 +521,10 @@ def send_message_to_lark(message: str, webhook_url: str): ...@@ -516,10 +521,10 @@ def send_message_to_lark(message: str, webhook_url: str):
requests.post(webhook_url, json=data) requests.post(webhook_url, json=data)
if __name__ == '__main__': if __name__ == "__main__":
GITHUB_TOKEN = os.environ['GITHUB_TOKEN'] GITHUB_TOKEN = os.environ["GITHUB_TOKEN"]
CONTRIBUTOR_IMAGE_PATH = 'contributor_leaderboard.png' CONTRIBUTOR_IMAGE_PATH = "contributor_leaderboard.png"
USER_ENGAGEMENT_IMAGE_PATH = 'engagement_leaderboard.png' USER_ENGAGEMENT_IMAGE_PATH = "engagement_leaderboard.png"
ORG_NAME = "hpcaitech" ORG_NAME = "hpcaitech"
# get all open source repositories # get all open source repositories
...@@ -527,17 +532,19 @@ if __name__ == '__main__': ...@@ -527,17 +532,19 @@ if __name__ == '__main__':
# generate images # generate images
contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, CONTRIBUTOR_IMAGE_PATH) contrib_success = generate_contributor_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, CONTRIBUTOR_IMAGE_PATH)
engagement_success = generate_user_engagement_leaderboard_image(GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH) engagement_success = generate_user_engagement_leaderboard_image(
GITHUB_TOKEN, ORG_NAME, REPO_LIST, USER_ENGAGEMENT_IMAGE_PATH
)
# upload images # upload images
APP_ID = os.environ['LARK_APP_ID'] APP_ID = os.environ["LARK_APP_ID"]
APP_SECRET = os.environ['LARK_APP_SECRET'] APP_SECRET = os.environ["LARK_APP_SECRET"]
LARK_TENANT_TOKEN = generate_lark_tenant_access_token(app_id=APP_ID, app_secret=APP_SECRET) LARK_TENANT_TOKEN = generate_lark_tenant_access_token(app_id=APP_ID, app_secret=APP_SECRET)
contributor_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, CONTRIBUTOR_IMAGE_PATH) contributor_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, CONTRIBUTOR_IMAGE_PATH)
user_engagement_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, USER_ENGAGEMENT_IMAGE_PATH) user_engagement_image_key = upload_image_to_lark(LARK_TENANT_TOKEN, USER_ENGAGEMENT_IMAGE_PATH)
# send message to lark # send message to lark
LARK_WEBHOOK_URL = os.environ['LARK_WEBHOOK_URL'] LARK_WEBHOOK_URL = os.environ["LARK_WEBHOOK_URL"]
message = """本周的社区榜单出炉啦! message = """本周的社区榜单出炉啦!
1. 开发贡献者榜单 1. 开发贡献者榜单
2. 用户互动榜单 2. 用户互动榜单
......
...@@ -7,27 +7,27 @@ import re ...@@ -7,27 +7,27 @@ import re
import requests import requests
COMMIT_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/commits' COMMIT_API = "https://api.github.com/repos/hpcaitech/ColossalAI/commits"
TAGS_API = 'https://api.github.com/repos/hpcaitech/ColossalAI/tags' TAGS_API = "https://api.github.com/repos/hpcaitech/ColossalAI/tags"
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--out', type=str, help='output path for the release draft', required=True) parser.add_argument("--out", type=str, help="output path for the release draft", required=True)
parser.add_argument('--version', type=str, help='current version to release', required=True) parser.add_argument("--version", type=str, help="current version to release", required=True)
return parser.parse_args() return parser.parse_args()
def get_latest_tag_commit(headers=None): def get_latest_tag_commit(headers=None):
res = requests.get(url=TAGS_API, headers=headers) res = requests.get(url=TAGS_API, headers=headers)
data = res.json() data = res.json()
commit_hash = data[0]['commit']['sha'] commit_hash = data[0]["commit"]["sha"]
version = data[0]['name'] version = data[0]["name"]
return commit_hash, version return commit_hash, version
def get_commit_info(commit_hash, headers=None): def get_commit_info(commit_hash, headers=None):
api = f'{COMMIT_API}/{commit_hash}' api = f"{COMMIT_API}/{commit_hash}"
res = requests.get(url=api, headers=headers) res = requests.get(url=api, headers=headers)
return res.json() return res.json()
...@@ -37,7 +37,7 @@ def get_all_commit_info(since, headers=None): ...@@ -37,7 +37,7 @@ def get_all_commit_info(since, headers=None):
results = [] results = []
while True: while True:
api = f'{COMMIT_API}?since={since}&per_page=100&page={page}' api = f"{COMMIT_API}?since={since}&per_page=100&page={page}"
resp = requests.get(url=api, headers=headers) resp = requests.get(url=api, headers=headers)
data = resp.json() data = resp.json()
...@@ -53,21 +53,21 @@ def get_all_commit_info(since, headers=None): ...@@ -53,21 +53,21 @@ def get_all_commit_info(since, headers=None):
def collate_release_info(commit_info_list): def collate_release_info(commit_info_list):
results = dict() results = dict()
pattern = pattern = r'\[.*\]' pattern = pattern = r"\[.*\]"
for commit_info in commit_info_list: for commit_info in commit_info_list:
author = commit_info['commit']['author']['name'] author = commit_info["commit"]["author"]["name"]
try: try:
author_url = commit_info['author']['url'] author_url = commit_info["author"]["url"]
except: except:
# author can be None # author can be None
author_url = None author_url = None
msg = commit_info['commit']['message'] msg = commit_info["commit"]["message"]
match = re.search(pattern, msg) match = re.search(pattern, msg)
if match: if match:
tag = match.group().lstrip('[').rstrip(']').capitalize() tag = match.group().lstrip("[").rstrip("]").capitalize()
if tag not in results: if tag not in results:
results[tag] = [] results[tag] = []
results[tag].append((msg, author, author_url)) results[tag].append((msg, author, author_url))
...@@ -89,32 +89,33 @@ def generate_release_post_markdown(current_version, last_version, release_info): ...@@ -89,32 +89,33 @@ def generate_release_post_markdown(current_version, last_version, release_info):
for msg, author, author_url in v: for msg, author, author_url in v:
# only keep the first line # only keep the first line
msg = msg.split('\n')[0] msg = msg.split("\n")[0]
if author_url: if author_url:
item = f'{msg} by [{author}]({author_url})\n' item = f"{msg} by [{author}]({author_url})\n"
else: else:
item = f'{msg} by {author}\n' item = f"{msg} by {author}\n"
text.append(f'- {item}') text.append(f"- {item}")
text.append('\n') text.append("\n")
# add full change log # add full change log
text.append( text.append(
f'**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}') f"**Full Changelog**: https://github.com/hpcaitech/ColossalAI/compare/{current_version}...{last_version}"
)
return text return text
if __name__ == '__main__': if __name__ == "__main__":
args = parse_args() args = parse_args()
token = os.environ['GITHUB_API_TOKEN'] token = os.environ["GITHUB_API_TOKEN"]
headers = {'Authorization': token} headers = {"Authorization": token}
# get previous release tag # get previous release tag
last_release_commit, last_version = get_latest_tag_commit(headers) last_release_commit, last_version = get_latest_tag_commit(headers)
last_release_commit_info = get_commit_info(last_release_commit, headers=headers) last_release_commit_info = get_commit_info(last_release_commit, headers=headers)
last_release_date = last_release_commit_info['commit']['author']['date'] last_release_date = last_release_commit_info["commit"]["author"]["date"]
# get the commits since last release # get the commits since last release
commit_info = get_all_commit_info(since=last_release_date, headers=headers) commit_info = get_all_commit_info(since=last_release_date, headers=headers)
...@@ -125,6 +126,6 @@ if __name__ == '__main__': ...@@ -125,6 +126,6 @@ if __name__ == '__main__':
markdown_text = generate_release_post_markdown(args.version, last_version, release_info) markdown_text = generate_release_post_markdown(args.version, last_version, release_info)
# write into a file # write into a file
with open(args.out, 'w') as f: with open(args.out, "w") as f:
for line in markdown_text: for line in markdown_text:
f.write(line) f.write(line)
...@@ -5,8 +5,8 @@ import requests ...@@ -5,8 +5,8 @@ import requests
def parse_args():
    """Parse the command-line options of the Lark notification script.

    Two optional string options are accepted, each with a short and a long
    spelling: ``-m/--message`` and ``-u/--url``.  An option that is not given
    on the command line is left as ``None`` in the result.

    Returns:
        argparse.Namespace: namespace with ``message`` and ``url`` attributes.
    """
    cli = argparse.ArgumentParser()
    # Both options share the same shape, so register them in one pass.
    for short_flag, long_flag in (("-m", "--message"), ("-u", "--url")):
        cli.add_argument(short_flag, long_flag, type=str)
    return cli.parse_args()
...@@ -15,6 +15,6 @@ def send_message_to_lark(message, webhook_url): ...@@ -15,6 +15,6 @@ def send_message_to_lark(message, webhook_url):
requests.post(webhook_url, json=data) requests.post(webhook_url, json=data)
if __name__ == "__main__":
    # Script entry point: forward the parsed CLI message to the given webhook.
    cli_args = parse_args()
    send_message_to_lark(message=cli_args.message, webhook_url=cli_args.url)
...@@ -3,3 +3,4 @@ line_length = 120 ...@@ -3,3 +3,4 @@ line_length = 120
multi_line_output=3 multi_line_output=3
include_trailing_comma = true include_trailing_comma = true
ignore_comments = true ignore_comments = true
profile = black
repos: repos:
- repo: https://github.com/PyCQA/autoflake
rev: v2.2.1
hooks:
- id: autoflake
name: autoflake (python)
args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
- repo: https://github.com/pycqa/isort - repo: https://github.com/pycqa/isort
rev: 5.12.0 rev: 5.12.0
hooks: hooks:
- id: isort - id: isort
name: sort all imports (python) name: sort all imports (python)
- repo: https://github.com/pre-commit/mirrors-yapf - repo: https://github.com/psf/black-pre-commit-mirror
rev: v0.32.0 rev: 23.9.1
hooks: hooks:
- id: yapf - id: black
name: yapf formatter name: black formatter
args: ['--style=.style.yapf', '--parallel', '--in-place'] args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: v13.0.1 rev: v13.0.1
hooks: hooks:
- id: clang-format - id: clang-format
name: clang formatter name: clang formatter
types_or: [c++, c]
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0 rev: v4.3.0
......
[style]
based_on_style = google
spaces_before_comment = 4
split_before_logical_operator = true
column_limit = 120
...@@ -27,7 +27,7 @@ def get_model_numel(model: nn.Module, strategy: Strategy) -> int: ...@@ -27,7 +27,7 @@ def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
def preprocess_batch(samples) -> dict:
    """Collate individual samples into a single model-input batch.

    Args:
        samples: sequence of tensors that ``torch.stack`` can combine
            (i.e. all of the same shape).

    Returns:
        dict: ``"input_ids"`` holds the samples stacked along a new leading
        dimension; ``"attention_mask"`` is an all-ones ``torch.long`` tensor
        of the same shape, so every position is attended to.
    """
    stacked_ids = torch.stack(samples)
    full_mask = torch.ones_like(stacked_ids, dtype=torch.long)
    return dict(input_ids=stacked_ids, attention_mask=full_mask)
def print_rank_0(*args, **kwargs) -> None: def print_rank_0(*args, **kwargs) -> None:
...@@ -39,32 +39,32 @@ def print_model_numel(model_dict: dict) -> None: ...@@ -39,32 +39,32 @@ def print_model_numel(model_dict: dict) -> None:
B = 1024**3 B = 1024**3
M = 1024**2 M = 1024**2
K = 1024 K = 1024
outputs = '' outputs = ""
for name, numel in model_dict.items(): for name, numel in model_dict.items():
outputs += f'{name}: ' outputs += f"{name}: "
if numel >= B: if numel >= B:
outputs += f'{numel / B:.2f} B\n' outputs += f"{numel / B:.2f} B\n"
elif numel >= M: elif numel >= M:
outputs += f'{numel / M:.2f} M\n' outputs += f"{numel / M:.2f} M\n"
elif numel >= K: elif numel >= K:
outputs += f'{numel / K:.2f} K\n' outputs += f"{numel / K:.2f} K\n"
else: else:
outputs += f'{numel}\n' outputs += f"{numel}\n"
print_rank_0(outputs) print_rank_0(outputs)
def get_gpt_config(model_name: str) -> OPTConfig: def get_gpt_config(model_name: str) -> OPTConfig:
model_map = { model_map = {
'125m': OPTConfig.from_pretrained('facebook/opt-125m'), "125m": OPTConfig.from_pretrained("facebook/opt-125m"),
'350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16), "350m": OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
'700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20), "700m": OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
'1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'), "1.3b": OPTConfig.from_pretrained("facebook/opt-1.3b"),
'2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'), "2.7b": OPTConfig.from_pretrained("facebook/opt-2.7b"),
'3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32), "3.5b": OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
'5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32), "5.5b": OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
'6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'), "6.7b": OPTConfig.from_pretrained("facebook/opt-6.7b"),
'10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32), "10b": OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
'13b': OPTConfig.from_pretrained('facebook/opt-13b'), "13b": OPTConfig.from_pretrained("facebook/opt-13b"),
} }
try: try:
return model_map[model_name] return model_map[model_name]
...@@ -73,20 +73,20 @@ def get_gpt_config(model_name: str) -> OPTConfig: ...@@ -73,20 +73,20 @@ def get_gpt_config(model_name: str) -> OPTConfig:
def main(args): def main(args):
if args.strategy == 'ddp': if args.strategy == "ddp":
strategy = DDPStrategy() strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini': elif args.strategy == "colossalai_gemini":
strategy = GeminiStrategy(placement_policy='cuda', initial_scale=2**5) strategy = GeminiStrategy(placement_policy="cuda", initial_scale=2**5)
elif args.strategy == 'colossalai_gemini_cpu': elif args.strategy == "colossalai_gemini_cpu":
strategy = GeminiStrategy(placement_policy='cpu', initial_scale=2**5) strategy = GeminiStrategy(placement_policy="cpu", initial_scale=2**5)
elif args.strategy == 'colossalai_zero2': elif args.strategy == "colossalai_zero2":
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cuda') strategy = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
elif args.strategy == 'colossalai_zero2_cpu': elif args.strategy == "colossalai_zero2_cpu":
strategy = LowLevelZeroStrategy(stage=2, placement_policy='cpu') strategy = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
elif args.strategy == 'colossalai_zero1': elif args.strategy == "colossalai_zero1":
strategy = LowLevelZeroStrategy(stage=1, placement_policy='cuda') strategy = LowLevelZeroStrategy(stage=1, placement_policy="cuda")
elif args.strategy == 'colossalai_zero1_cpu': elif args.strategy == "colossalai_zero1_cpu":
strategy = LowLevelZeroStrategy(stage=1, placement_policy='cpu') strategy = LowLevelZeroStrategy(stage=1, placement_policy="cpu")
else: else:
raise ValueError(f'Unsupported strategy "{args.strategy}"') raise ValueError(f'Unsupported strategy "{args.strategy}"')
...@@ -103,45 +103,51 @@ def main(args): ...@@ -103,45 +103,51 @@ def main(args):
if args.use_kernels: if args.use_kernels:
from coati.kernels import convert_to_xformer_model from coati.kernels import convert_to_xformer_model
actor, critic, initial_model, reward_model = map(convert_to_xformer_model,
(actor, critic, initial_model, reward_model)) actor, critic, initial_model, reward_model = map(
convert_to_xformer_model, (actor, critic, initial_model, reward_model)
)
actor_numel = get_model_numel(actor, strategy) actor_numel = get_model_numel(actor, strategy)
critic_numel = get_model_numel(critic, strategy) critic_numel = get_model_numel(critic, strategy)
initial_model_numel = get_model_numel(initial_model, strategy) initial_model_numel = get_model_numel(initial_model, strategy)
reward_model_numel = get_model_numel(reward_model, strategy) reward_model_numel = get_model_numel(reward_model, strategy)
print_model_numel({ print_model_numel(
'Actor': actor_numel, {
'Critic': critic_numel, "Actor": actor_numel,
'Initial model': initial_model_numel, "Critic": critic_numel,
'Reward model': reward_model_numel "Initial model": initial_model_numel,
}) "Reward model": reward_model_numel,
performance_evaluator = PerformanceEvaluator(actor_numel, }
)
performance_evaluator = PerformanceEvaluator(
actor_numel,
critic_numel, critic_numel,
initial_model_numel, initial_model_numel,
reward_model_numel, reward_model_numel,
enable_grad_checkpoint=False, enable_grad_checkpoint=False,
ignore_episodes=1) ignore_episodes=1,
)
if args.strategy.startswith('colossalai'): if args.strategy.startswith("colossalai"):
actor_optim = HybridAdam(actor.parameters(), lr=5e-6) actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
critic_optim = HybridAdam(critic.parameters(), lr=5e-6) critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
else: else:
actor_optim = Adam(actor.parameters(), lr=5e-6) actor_optim = Adam(actor.parameters(), lr=5e-6)
critic_optim = Adam(critic.parameters(), lr=5e-6) critic_optim = Adam(critic.parameters(), lr=5e-6)
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim)) (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device()) random_prompts = torch.randint(tokenizer.vocab_size, (1000, 256), device=torch.cuda.current_device())
dataloader = DataLoader(random_prompts, dataloader = DataLoader(
batch_size=args.experience_batch_size, random_prompts, batch_size=args.experience_batch_size, shuffle=True, collate_fn=preprocess_batch
shuffle=True, )
collate_fn=preprocess_batch)
trainer = PPOTrainer(strategy, trainer = PPOTrainer(
strategy,
actor, actor,
critic, critic,
reward_model, reward_model,
...@@ -158,35 +164,45 @@ def main(args): ...@@ -158,35 +164,45 @@ def main(args):
use_cache=True, use_cache=True,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id,
callbacks=[performance_evaluator]) callbacks=[performance_evaluator],
)
trainer.fit(prompt_dataloader=dataloader, trainer.fit(
prompt_dataloader=dataloader,
pretrain_dataloader=None, pretrain_dataloader=None,
num_episodes=args.num_episodes, num_episodes=args.num_episodes,
num_update_steps=args.num_update_steps, num_update_steps=args.num_update_steps,
num_collect_steps=args.num_collect_steps) num_collect_steps=args.num_collect_steps,
)
print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB') print_rank_0(f"Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB")
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', default='125m') parser.add_argument("--model", default="125m")
parser.add_argument('--critic_model', default='125m') parser.add_argument("--critic_model", default="125m")
parser.add_argument('--strategy', parser.add_argument(
"--strategy",
choices=[ choices=[
'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2', "ddp",
'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu' "colossalai_gemini",
"colossalai_gemini_cpu",
"colossalai_zero2",
"colossalai_zero2_cpu",
"colossalai_zero1",
"colossalai_zero1_cpu",
], ],
default='ddp') default="ddp",
parser.add_argument('--num_episodes', type=int, default=3) )
parser.add_argument('--num_collect_steps', type=int, default=8) parser.add_argument("--num_episodes", type=int, default=3)
parser.add_argument('--num_update_steps', type=int, default=1) parser.add_argument("--num_collect_steps", type=int, default=8)
parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument("--num_update_steps", type=int, default=1)
parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument('--lora_rank', type=int, default=0) parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument('--cuda_mem_frac', type=float, default=1.0) parser.add_argument("--lora_rank", type=int, default=0)
parser.add_argument('--offload_inference_models', action='store_true', default=False) parser.add_argument("--cuda_mem_frac", type=float, default=1.0)
parser.add_argument('--use_kernels', action='store_true', default=False) parser.add_argument("--offload_inference_models", action="store_true", default=False)
parser.add_argument("--use_kernels", action="store_true", default=False)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
...@@ -22,13 +22,13 @@ from transformers.modeling_utils import no_init_weights ...@@ -22,13 +22,13 @@ from transformers.modeling_utils import no_init_weights
def get_free_port():
    """Return the number of a TCP port that the OS reports as unused.

    A throw-away IPv4 TCP socket is bound to port 0 and the port it actually
    received is read back.  The socket is closed before returning, so the
    port is only known to have been free at the time of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("", 0))
        # getsockname() on an AF_INET socket yields (host, port).
        _host, port = probe.getsockname()
        return port
def get_local_ip():
    """Return the local IPv4 address used to reach the outside network.

    A UDP socket is connected to 8.8.8.8:80 and the host part of the socket's
    own address is returned.  The call raises ``OSError`` if the connect
    fails (e.g. no usable route).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        probe.connect(("8.8.8.8", 80))
        local_endpoint = probe.getsockname()
    # local_endpoint is (host, port); only the host is of interest.
    return local_endpoint[0]
...@@ -36,22 +36,25 @@ def main(args): ...@@ -36,22 +36,25 @@ def main(args):
master_addr = str(get_local_ip()) master_addr = str(get_local_ip())
# trainer_env_info # trainer_env_info
trainer_port = str(get_free_port()) trainer_port = str(get_free_port())
env_info_trainers = [{ env_info_trainers = [
'local_rank': '0', {
'rank': str(rank), "local_rank": "0",
'world_size': str(args.num_trainers), "rank": str(rank),
'master_port': trainer_port, "world_size": str(args.num_trainers),
'master_addr': master_addr "master_port": trainer_port,
} for rank in range(args.num_trainers)] "master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info # maker_env_info
maker_port = str(get_free_port()) maker_port = str(get_free_port())
env_info_maker = { env_info_maker = {
'local_rank': '0', "local_rank": "0",
'rank': '0', "rank": "0",
'world_size': '1', "world_size": "1",
'master_port': maker_port, "master_port": maker_port,
'master_addr': master_addr "master_addr": master_addr,
} }
# configure tokenizer # configure tokenizer
...@@ -63,21 +66,27 @@ def main(args): ...@@ -63,21 +66,27 @@ def main(args):
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain) critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.critic_model, reward_model = (
config=critic_cfg).requires_grad_(False).half().cuda() get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama': )
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model # quantize initial model
with low_resource_init(), no_init_weights(): with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg) initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, initial_model.model = (
args.quant_group_size).cuda().requires_grad_(False) llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else: else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model return actor, critic, reward_model, initial_model
# configure Experience Maker # configure Experience Maker
experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote( experience_holder_ref = ExperienceMakerHolder.options(name="maker0", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[f'trainer{i}' for i in range(args.num_trainers)], detached_trainer_name_list=[f"trainer{i}" for i in range(args.num_trainers)],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy), strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
model_fn=model_fn, model_fn=model_fn,
env_info=env_info_maker, env_info=env_info_maker,
...@@ -97,15 +106,18 @@ def main(args): ...@@ -97,15 +106,18 @@ def main(args):
def trainer_model_fn(): def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda() actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
critic = get_critic_from_args(args.critic_model, critic = (
config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda() get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
.half()
.cuda()
)
return actor, critic return actor, critic
# configure Trainer # configure Trainer
trainer_refs = [ trainer_refs = [
DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote( DetachedPPOTrainer.options(name=f"trainer{i}", num_gpus=1, max_concurrency=2).remote(
experience_maker_holder_name_list=[ experience_maker_holder_name_list=[
f'maker{x}' for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True) f"maker{x}" for x in get_receivers_per_sender(i, args.num_trainers, 1, allow_idle_sender=True)
], ],
strategy_fn=partial(get_strategy_from_args, args.trainer_strategy), strategy_fn=partial(get_strategy_from_args, args.trainer_strategy),
model_fn=trainer_model_fn, model_fn=trainer_model_fn,
...@@ -114,7 +126,8 @@ def main(args): ...@@ -114,7 +126,8 @@ def main(args):
buffer_limit=16, buffer_limit=16,
eval_performance=True, eval_performance=True,
debug=args.debug, debug=args.debug,
) for i, env_info_trainer in enumerate(env_info_trainers) )
for i, env_info_trainer in enumerate(env_info_trainers)
] ]
dataset_size = args.experience_batch_size * 4 dataset_size = args.experience_batch_size * 4
...@@ -122,7 +135,7 @@ def main(args): ...@@ -122,7 +135,7 @@ def main(args):
def data_gen_fn(): def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device()) input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids) attn_mask = torch.ones_like(input_ids)
return {'input_ids': input_ids, 'attention_mask': attn_mask} return {"input_ids": input_ids, "attention_mask": attn_mask}
def build_dataloader(size): def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)] dataset = [data_gen_fn() for _ in range(size)]
...@@ -138,8 +151,10 @@ def main(args): ...@@ -138,8 +151,10 @@ def main(args):
wait_tasks = [] wait_tasks = []
wait_tasks.append( wait_tasks.append(
experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size), experience_holder_ref.workingloop.remote(
num_steps=args.experience_steps)) partial(build_dataloader, dataset_size), num_steps=args.experience_steps
)
)
total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size) total_steps = args.experience_batch_size * args.experience_steps // (args.num_trainers * args.train_batch_size)
for trainer_ref in trainer_refs: for trainer_ref in trainer_refs:
...@@ -148,31 +163,30 @@ def main(args): ...@@ -148,31 +163,30 @@ def main(args):
ray.get(wait_tasks) ray.get(wait_tasks)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--num_trainers', type=int, default=1) parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument('--trainer_strategy', parser.add_argument(
choices=[ "--trainer_strategy",
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
'colossalai_zero2_cpu' default="ddp",
], )
default='ddp') parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument('--maker_strategy', choices=['naive'], default='naive') parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None) parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument('--experience_steps', type=int, default=4) parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument('--train_epochs', type=int, default=1) parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument('--update_steps', type=int, default=2) parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None) parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument('--quant_bits', type=int, default=4) parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument('--quant_group_size', type=int, default=128) parser.add_argument("--debug", action="store_true")
parser.add_argument('--debug', action='store_true')
args = parser.parse_args() args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args) main(args)
...@@ -22,13 +22,13 @@ from transformers.modeling_utils import no_init_weights ...@@ -22,13 +22,13 @@ from transformers.modeling_utils import no_init_weights
def get_free_port():
    """Return the number of a TCP port that the OS reports as unused.

    A throw-away IPv4 TCP socket is bound to port 0 and the port it actually
    received is read back.  The socket is closed before returning, so the
    port is only known to have been free at the time of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("", 0))
        # getsockname() on an AF_INET socket yields (host, port).
        _host, port = probe.getsockname()
        return port
def get_local_ip():
    """Return the local IPv4 address used to reach the outside network.

    A UDP socket is connected to 8.8.8.8:80 and the host part of the socket's
    own address is returned.  The call raises ``OSError`` if the connect
    fails (e.g. no usable route).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        probe.connect(("8.8.8.8", 80))
        local_endpoint = probe.getsockname()
    # local_endpoint is (host, port); only the host is of interest.
    return local_endpoint[0]
...@@ -36,23 +36,29 @@ def main(args): ...@@ -36,23 +36,29 @@ def main(args):
master_addr = str(get_local_ip()) master_addr = str(get_local_ip())
# trainer_env_info # trainer_env_info
trainer_port = str(get_free_port()) trainer_port = str(get_free_port())
env_info_trainers = [{ env_info_trainers = [
'local_rank': '0', {
'rank': str(rank), "local_rank": "0",
'world_size': str(args.num_trainers), "rank": str(rank),
'master_port': trainer_port, "world_size": str(args.num_trainers),
'master_addr': master_addr "master_port": trainer_port,
} for rank in range(args.num_trainers)] "master_addr": master_addr,
}
for rank in range(args.num_trainers)
]
# maker_env_info # maker_env_info
maker_port = str(get_free_port()) maker_port = str(get_free_port())
env_info_makers = [{ env_info_makers = [
'local_rank': '0', {
'rank': str(rank), "local_rank": "0",
'world_size': str(args.num_makers), "rank": str(rank),
'master_port': maker_port, "world_size": str(args.num_makers),
'master_addr': master_addr "master_port": maker_port,
} for rank in range(args.num_makers)] "master_addr": master_addr,
}
for rank in range(args.num_makers)
]
# configure tokenizer # configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.pretrain) tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
...@@ -63,14 +69,20 @@ def main(args): ...@@ -63,14 +69,20 @@ def main(args):
critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain) critic_cfg = AutoConfig.from_pretrained(args.critic_pretrain)
actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() actor = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda() critic = get_critic_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
reward_model = get_reward_model_from_args(args.critic_model, reward_model = (
config=critic_cfg).requires_grad_(False).half().cuda() get_reward_model_from_args(args.critic_model, config=critic_cfg).requires_grad_(False).half().cuda()
if args.initial_model_quant_ckpt is not None and args.model == 'llama': )
if args.initial_model_quant_ckpt is not None and args.model == "llama":
# quantize initial model # quantize initial model
with low_resource_init(), no_init_weights(): with low_resource_init(), no_init_weights():
initial_model = get_actor_from_args(args.model, config=actor_cfg) initial_model = get_actor_from_args(args.model, config=actor_cfg)
initial_model.model = llama_load_quant(initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, initial_model.model = (
args.quant_group_size).cuda().requires_grad_(False) llama_load_quant(
initial_model.model, args.initial_model_quant_ckpt, args.quant_bits, args.quant_group_size
)
.cuda()
.requires_grad_(False)
)
else: else:
initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda() initial_model = get_actor_from_args(args.model, config=actor_cfg).requires_grad_(False).half().cuda()
return actor, critic, reward_model, initial_model return actor, critic, reward_model, initial_model
...@@ -79,7 +91,7 @@ def main(args): ...@@ -79,7 +91,7 @@ def main(args):
experience_holder_refs = [ experience_holder_refs = [
ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote( ExperienceMakerHolder.options(name=f"maker{i}", num_gpus=1, max_concurrency=2).remote(
detached_trainer_name_list=[ detached_trainer_name_list=[
f'trainer{x}' f"trainer{x}"
for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False) for x in get_receivers_per_sender(i, args.num_makers, args.num_trainers, allow_idle_sender=False)
], ],
strategy_fn=partial(get_strategy_from_args, args.maker_strategy), strategy_fn=partial(get_strategy_from_args, args.maker_strategy),
...@@ -103,8 +115,11 @@ def main(args): ...@@ -103,8 +115,11 @@ def main(args):
def trainer_model_fn(): def trainer_model_fn():
actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda() actor = get_actor_from_args(args.model, config=AutoConfig.from_pretrained(args.pretrain)).half().cuda()
critic = get_critic_from_args(args.critic_model, critic = (
config=AutoConfig.from_pretrained(args.critic_pretrain)).half().cuda() get_critic_from_args(args.critic_model, config=AutoConfig.from_pretrained(args.critic_pretrain))
.half()
.cuda()
)
return actor, critic return actor, critic
# configure Trainer # configure Trainer
...@@ -130,7 +145,7 @@ def main(args): ...@@ -130,7 +145,7 @@ def main(args):
def data_gen_fn(): def data_gen_fn():
input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device()) input_ids = torch.randint(tokenizer.vocab_size, (256,), device=torch.cuda.current_device())
attn_mask = torch.ones_like(input_ids) attn_mask = torch.ones_like(input_ids)
return {'input_ids': input_ids, 'attention_mask': attn_mask} return {"input_ids": input_ids, "attention_mask": attn_mask}
def build_dataloader(size): def build_dataloader(size):
dataset = [data_gen_fn() for _ in range(size)] dataset = [data_gen_fn() for _ in range(size)]
...@@ -147,43 +162,48 @@ def main(args): ...@@ -147,43 +162,48 @@ def main(args):
for experience_holder_ref in experience_holder_refs: for experience_holder_ref in experience_holder_refs:
wait_tasks.append( wait_tasks.append(
experience_holder_ref.workingloop.remote(partial(build_dataloader, dataset_size), experience_holder_ref.workingloop.remote(
num_steps=args.experience_steps)) partial(build_dataloader, dataset_size), num_steps=args.experience_steps
)
)
total_steps = args.experience_batch_size * args.experience_steps * \ total_steps = (
args.num_makers // (args.num_trainers * args.train_batch_size) args.experience_batch_size
* args.experience_steps
* args.num_makers
// (args.num_trainers * args.train_batch_size)
)
for trainer_ref in trainer_refs: for trainer_ref in trainer_refs:
wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs)) wait_tasks.append(trainer_ref.fit.remote(total_steps, args.update_steps, args.train_epochs))
ray.get(wait_tasks) ray.get(wait_tasks)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--num_makers', type=int, default=1) parser.add_argument("--num_makers", type=int, default=1)
parser.add_argument('--num_trainers', type=int, default=1) parser.add_argument("--num_trainers", type=int, default=1)
parser.add_argument('--trainer_strategy', parser.add_argument(
choices=[ "--trainer_strategy",
'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_gemini_cpu', choices=["ddp", "colossalai_gemini", "colossalai_zero2", "colossalai_gemini_cpu", "colossalai_zero2_cpu"],
'colossalai_zero2_cpu' default="ddp",
], )
default='ddp') parser.add_argument("--maker_strategy", choices=["naive"], default="naive")
parser.add_argument('--maker_strategy', choices=['naive'], default='naive') parser.add_argument("--model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--critic_model", default="gpt2", choices=["gpt2", "bloom", "opt", "llama"])
parser.add_argument('--critic_model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama']) parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument("--critic_pretrain", type=str, default=None)
parser.add_argument('--critic_pretrain', type=str, default=None) parser.add_argument("--experience_steps", type=int, default=4)
parser.add_argument('--experience_steps', type=int, default=4) parser.add_argument("--experience_batch_size", type=int, default=8)
parser.add_argument('--experience_batch_size', type=int, default=8) parser.add_argument("--train_epochs", type=int, default=1)
parser.add_argument('--train_epochs', type=int, default=1) parser.add_argument("--update_steps", type=int, default=2)
parser.add_argument('--update_steps', type=int, default=2) parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument('--train_batch_size', type=int, default=8) parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument("--initial_model_quant_ckpt", type=str, default=None)
parser.add_argument('--initial_model_quant_ckpt', type=str, default=None) parser.add_argument("--quant_bits", type=int, default=4)
parser.add_argument('--quant_bits', type=int, default=4) parser.add_argument("--quant_group_size", type=int, default=128)
parser.add_argument('--quant_group_size', type=int, default=128) parser.add_argument("--debug", action="store_true")
parser.add_argument('--debug', action='store_true')
args = parser.parse_args() args = parser.parse_args()
ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)}) ray.init(namespace=os.environ["RAY_NAMESPACE"], runtime_env={"env_vars": dict(os.environ)})
main(args) main(args)
...@@ -4,7 +4,10 @@ from .sft_dataset import SFTDataset, SupervisedDataset ...@@ -4,7 +4,10 @@ from .sft_dataset import SFTDataset, SupervisedDataset
from .utils import is_rank_0 from .utils import is_rank_0
__all__ = [ __all__ = [
'RmStaticDataset', 'HhRlhfDataset', "RmStaticDataset",
'SFTDataset', 'SupervisedDataset', "HhRlhfDataset",
'PromptDataset', 'is_rank_0', "SFTDataset",
"SupervisedDataset",
"PromptDataset",
"is_rank_0",
] ]
...@@ -49,7 +49,7 @@ class Conversation: ...@@ -49,7 +49,7 @@ class Conversation:
def to_gradio_chatbot(self): def to_gradio_chatbot(self):
ret = [] ret = []
for i, (role, msg) in enumerate(self.messages[self.offset:]): for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0: if i % 2 == 0:
ret.append([msg, None]) ret.append([msg, None])
else: else:
...@@ -57,12 +57,14 @@ class Conversation: ...@@ -57,12 +57,14 @@ class Conversation:
return ret return ret
def copy(self): def copy(self):
return Conversation(system=self.system, return Conversation(
system=self.system,
roles=self.roles, roles=self.roles,
messages=[[x, y] for x, y in self.messages], messages=[[x, y] for x, y in self.messages],
offset=self.offset, offset=self.offset,
sep_style=self.sep_style, sep_style=self.sep_style,
sep=self.sep) sep=self.sep,
)
def dict(self): def dict(self):
return { return {
...@@ -70,7 +72,7 @@ class Conversation: ...@@ -70,7 +72,7 @@ class Conversation:
"roles": self.roles, "roles": self.roles,
"messages": self.messages, "messages": self.messages,
"offset": self.offset, "offset": self.offset,
"sep": self.sep "sep": self.sep,
} }
......
...@@ -13,11 +13,13 @@ from .utils import jload ...@@ -13,11 +13,13 @@ from .utils import jload
class PromptDataset(Dataset): class PromptDataset(Dataset):
"""Dataset for supervised fine-tuning.""" """Dataset for supervised fine-tuning."""
def __init__(self, def __init__(
self,
data_path: str, data_path: str,
tokenizer: transformers.PreTrainedTokenizer, tokenizer: transformers.PreTrainedTokenizer,
max_datasets_size: int = None, max_datasets_size: int = None,
max_length: int = 96): max_length: int = 96,
):
super(PromptDataset, self).__init__() super(PromptDataset, self).__init__()
self.keyed_prompt = defaultdict(list) self.keyed_prompt = defaultdict(list)
self.logger = get_dist_logger() self.logger = get_dist_logger()
...@@ -30,11 +32,9 @@ class PromptDataset(Dataset): ...@@ -30,11 +32,9 @@ class PromptDataset(Dataset):
list_data_dict = list_data_dict[:max_datasets_size] list_data_dict = list_data_dict[:max_datasets_size]
instructions = [data_dict["instruction"] for data_dict in list_data_dict] instructions = [data_dict["instruction"] for data_dict in list_data_dict]
tokens = tokenizer(instructions, tokens = tokenizer(
return_tensors='pt', instructions, return_tensors="pt", max_length=max_length, padding="max_length", truncation=True
max_length=max_length, )
padding='max_length',
truncation=True)
for k, tensor in tokens.items(): for k, tensor in tokens.items():
self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind() self.keyed_prompt[k] = tensor.to(torch.cuda.current_device()).unbind()
......
...@@ -20,44 +20,31 @@ class RmStaticDataset(Dataset): ...@@ -20,44 +20,31 @@ class RmStaticDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__() super().__init__()
self.end_token = tokenizer.eos_token \ self.end_token = tokenizer.eos_token if special_token is None else special_token
if special_token is None else special_token
chosen = [data["prompt"] + data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
chosen = [ chosen_token = tokenizer(
data["prompt"] + data["chosen"] + self.end_token chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
for data in tqdm(dataset, disable=not is_rank_0()) )
] self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
chosen_token = tokenizer(chosen,
max_length=max_length, reject = [data["prompt"] + data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
padding="max_length", reject_token = tokenizer(
truncation=True, reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
return_tensors="pt") )
self.chosen = { self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}
reject = [
data["prompt"] + data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}
def __len__(self): def __len__(self):
length = self.chosen["input_ids"].shape[0] length = self.chosen["input_ids"].shape[0]
return length return length
def __getitem__(self, idx): def __getitem__(self, idx):
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \ return (
self.reject["input_ids"][idx], self.reject["attention_mask"][idx] self.chosen["input_ids"][idx],
self.chosen["attention_mask"][idx],
self.reject["input_ids"][idx],
self.reject["attention_mask"][idx],
)
# Anthropic/hh-rlhf # Anthropic/hh-rlhf
...@@ -74,41 +61,28 @@ class HhRlhfDataset(Dataset): ...@@ -74,41 +61,28 @@ class HhRlhfDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None: def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
super().__init__() super().__init__()
self.end_token = tokenizer.eos_token \ self.end_token = tokenizer.eos_token if special_token is None else special_token
if special_token is None else special_token
chosen = [data["chosen"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
chosen = [ chosen_token = tokenizer(
data["chosen"] + self.end_token chosen, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
for data in tqdm(dataset, disable=not is_rank_0()) )
] self.chosen = {"input_ids": chosen_token["input_ids"], "attention_mask": chosen_token["attention_mask"]}
chosen_token = tokenizer(chosen,
max_length=max_length, reject = [data["rejected"] + self.end_token for data in tqdm(dataset, disable=not is_rank_0())]
padding="max_length", reject_token = tokenizer(
truncation=True, reject, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
return_tensors="pt") )
self.chosen = { self.reject = {"input_ids": reject_token["input_ids"], "attention_mask": reject_token["attention_mask"]}
"input_ids": chosen_token["input_ids"],
"attention_mask": chosen_token["attention_mask"]
}
reject = [
data["rejected"] + self.end_token
for data in tqdm(dataset, disable=not is_rank_0())
]
reject_token = tokenizer(reject,
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
self.reject = {
"input_ids": reject_token["input_ids"],
"attention_mask": reject_token["attention_mask"]
}
def __len__(self): def __len__(self):
length = self.chosen["input_ids"].shape[0] length = self.chosen["input_ids"].shape[0]
return length return length
def __getitem__(self, idx): def __getitem__(self, idx):
return self.chosen["input_ids"][idx], self.chosen["attention_mask"][idx], \ return (
self.reject["input_ids"][idx], self.reject["attention_mask"][idx] self.chosen["input_ids"][idx],
self.chosen["attention_mask"][idx],
self.reject["input_ids"][idx],
self.reject["attention_mask"][idx],
)
...@@ -16,10 +16,11 @@ import copy ...@@ -16,10 +16,11 @@ import copy
from typing import Dict, Sequence, Tuple from typing import Dict, Sequence, Tuple
import torch import torch
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from torch.utils.data import Dataset from torch.utils.data import Dataset
from tqdm import tqdm from tqdm import tqdm
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from .utils import is_rank_0, jload from .utils import is_rank_0, jload
...@@ -28,32 +29,33 @@ logger = get_dist_logger() ...@@ -28,32 +29,33 @@ logger = get_dist_logger()
IGNORE_INDEX = -100 IGNORE_INDEX = -100
PROMPT_DICT = { PROMPT_DICT = {
"prompt_input": ("Below is an instruction that describes a task, paired with an input that provides further context. " "prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n" "Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"), "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
"prompt_no_input": ("Below is an instruction that describes a task. " ),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n" "Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"), "### Instruction:\n{instruction}\n\n### Response:"
),
} }
def _preprocess(sources: Sequence[str], def _preprocess(
sources: Sequence[str],
targets: Sequence[str], targets: Sequence[str],
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
max_length: int, max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Preprocess the data by tokenizing.""" """Preprocess the data by tokenizing."""
sequences = [s + t for s, t in zip(sources, targets)] sequences = [s + t for s, t in zip(sources, targets)]
sequences_token = tokenizer(sequences, sequences_token = tokenizer(
max_length=max_length, sequences, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
padding="max_length", )
truncation=True, sources_token = tokenizer(
return_tensors="pt") sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
sources_token = tokenizer(sources, )
max_length=max_length,
padding="max_length",
truncation=True,
return_tensors="pt")
labels = copy.deepcopy(sequences_token["input_ids"]) labels = copy.deepcopy(sequences_token["input_ids"])
for i in range(labels.shape[0]): for i in range(labels.shape[0]):
...@@ -64,18 +66,19 @@ def _preprocess(sources: Sequence[str], ...@@ -64,18 +66,19 @@ def _preprocess(sources: Sequence[str],
labels[i][:source_len] = IGNORE_INDEX labels[i][:source_len] = IGNORE_INDEX
elif tokenizer.padding_side == "left": elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos| # |pad|prompt|completion|eos|
labels[i][pad_len:pad_len + source_len] = IGNORE_INDEX labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX
else: else:
raise RuntimeError() raise RuntimeError()
return sequences_token["input_ids"], labels, sequences_token["attention_mask"] return sequences_token["input_ids"], labels, sequences_token["attention_mask"]
def _preprocess_chatglm(sources: Sequence[str], def _preprocess_chatglm(
sources: Sequence[str],
targets: Sequence[str], targets: Sequence[str],
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
max_length: int, max_length: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Preprocess the data by tokenizing. Preprocess the data by tokenizing.
None for attention mask, ChatGLM will calculate attention mask according to input ids None for attention mask, ChatGLM will calculate attention mask according to input ids
...@@ -90,15 +93,15 @@ def _preprocess_chatglm(sources: Sequence[str], ...@@ -90,15 +93,15 @@ def _preprocess_chatglm(sources: Sequence[str],
# truncate # truncate
sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id] sp_token_list = [tokenizer.gmask_token_id, tokenizer.bos_token_id]
truncate_length = max(0, len(input_id) - max_length) truncate_length = max(0, len(input_id) - max_length)
input_id = input_id[truncate_length: ] input_id = input_id[truncate_length:]
if truncate_length == len(source_id) + 1: if truncate_length == len(source_id) + 1:
input_id = sp_token_list + input_id[1: ] input_id = sp_token_list + input_id[1:]
elif truncate_length > len(source_id) + 1: elif truncate_length > len(source_id) + 1:
input_id = sp_token_list + input_id[2: ] input_id = sp_token_list + input_id[2:]
context_length = input_id.index(tokenizer.bos_token_id) context_length = input_id.index(tokenizer.bos_token_id)
mask_position = context_length - 1 mask_position = context_length - 1
label = [IGNORE_INDEX] * context_length + input_id[mask_position+1:] label = [IGNORE_INDEX] * context_length + input_id[mask_position + 1 :]
pad_len = max_length - len(input_id) pad_len = max_length - len(input_id)
input_id = input_id + [tokenizer.pad_token_id] * pad_len input_id = input_id + [tokenizer.pad_token_id] * pad_len
...@@ -117,25 +120,18 @@ class SFTDataset(Dataset): ...@@ -117,25 +120,18 @@ class SFTDataset(Dataset):
max_length: max length of input max_length: max length of input
""" """
def __init__(self, def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int = 512) -> None:
dataset: Dict,
tokenizer: PreTrainedTokenizer,
max_length: int = 512
) -> None:
super().__init__() super().__init__()
self.input_ids = [] self.input_ids = []
sources = [data["prompt"] for data in dataset] sources = [data["prompt"] for data in dataset]
targets = [ targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]
data["completion"] + tokenizer.eos_token
for data in tqdm(dataset, disable=not is_rank_0())
]
if isinstance(tokenizer, ChatGLMTokenizer): if isinstance(tokenizer, ChatGLMTokenizer):
self.input_ids, self.labels, self.attention_mask = \ self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
_preprocess_chatglm(sources, targets, tokenizer, max_length) sources, targets, tokenizer, max_length
)
else: else:
self.input_ids, self.labels, self.attention_mask = \ self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
_preprocess(sources, targets, tokenizer, max_length)
def __len__(self): def __len__(self):
length = self.input_ids.shape[0] length = self.input_ids.shape[0]
...@@ -143,22 +139,17 @@ class SFTDataset(Dataset): ...@@ -143,22 +139,17 @@ class SFTDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
if self.attention_mask is not None: if self.attention_mask is not None:
return dict(input_ids=self.input_ids[idx], return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
labels=self.labels[idx],
attention_mask=self.attention_mask[idx])
else: else:
return dict(input_ids=self.input_ids[idx], return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
labels=self.labels[idx])
class SupervisedDataset(Dataset): class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning.""" """Dataset for supervised fine-tuning."""
def __init__(self, def __init__(
data_path: str, self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512
tokenizer: PreTrainedTokenizer, ):
max_datasets_size: int = None,
max_length: int = 512):
super().__init__() super().__init__()
logger.info("Loading data...") logger.info("Loading data...")
list_data_dict = jload(data_path) list_data_dict = jload(data_path)
...@@ -174,18 +165,15 @@ class SupervisedDataset(Dataset): ...@@ -174,18 +165,15 @@ class SupervisedDataset(Dataset):
prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example) prompt_input.format_map(example) if "input" in example else prompt_no_input.format_map(example)
for example in list_data_dict for example in list_data_dict
] ]
targets = [ targets = [example["output"] + tokenizer.eos_token for example in list_data_dict]
example['output'] + tokenizer.eos_token
for example in list_data_dict
]
logger.info("Tokenizing inputs... This may take some time...") logger.info("Tokenizing inputs... This may take some time...")
if isinstance(tokenizer, ChatGLMTokenizer): if isinstance(tokenizer, ChatGLMTokenizer):
self.input_ids, self.labels, self.attention_mask = \ self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
_preprocess_chatglm(sources, targets, tokenizer, max_length) sources, targets, tokenizer, max_length
)
else: else:
self.input_ids, self.labels, self.attention_mask = \ self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)
_preprocess(sources, targets, tokenizer, max_length)
def __len__(self): def __len__(self):
length = self.input_ids.shape[0] length = self.input_ids.shape[0]
...@@ -193,9 +181,6 @@ class SupervisedDataset(Dataset): ...@@ -193,9 +181,6 @@ class SupervisedDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
if self.attention_mask is not None: if self.attention_mask is not None:
return dict(input_ids=self.input_ids[idx], return dict(input_ids=self.input_ids[idx], labels=self.labels[idx], attention_mask=self.attention_mask[idx])
labels=self.labels[idx],
attention_mask=self.attention_mask[idx])
else: else:
return dict(input_ids=self.input_ids[idx], return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
labels=self.labels[idx])
from .base import ExperienceBuffer from .base import ExperienceBuffer
from .naive import NaiveExperienceBuffer from .naive import NaiveExperienceBuffer
__all__ = ['ExperienceBuffer', 'NaiveExperienceBuffer'] __all__ = ["ExperienceBuffer", "NaiveExperienceBuffer"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment