Commit d7117b95 authored by zhouxiang's avatar zhouxiang
Browse files

同步0.2.6代码

parent 5f83e392
accelerate accelerate
datasets datasets
flash-attn transformers_stream_generator
accelerate
mmengine-lite mmengine-lite
pillow
pydantic
torch torch
transformers transformers
urllib3==1.24 urllib3<2.0.0
fastapi
fire fire
mmengine-lite mmengine-lite
numpy numpy
peft<=0.9.0
pillow
pydantic>2.0.0
pynvml
safetensors safetensors
sentencepiece sentencepiece
shortuuid
tiktoken tiktoken
torch torch<=2.1.2,>=2.0.0
transformers==4.33.2 transformers>=4.33.0,<=4.38.1
triton>=2.1.0,<=2.2.0
uvicorn
fastapi gradio<4.0.0
gradio==3.50.2 protobuf
pydantic>2.0.0 tritonclient[grpc]
shortuuid
uvicorn
allure-pytest allure-pytest
coverage coverage
pynvml pynvml
pytest pytest==8.0.2
pytest-assume
pytest-order
pytest-rerunfailures
pytest-sugar
pytest-xdist
pyyaml pyyaml
<svg width="724" height="169" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" overflow="hidden"><defs><clipPath id="clip0"><rect x="290" y="255" width="724" height="169"/></clipPath><linearGradient x1="515.209" y1="187.434" x2="675.945" y2="480.272" gradientUnits="userSpaceOnUse" spreadMethod="pad" id="fill1"><stop offset="0" stop-color="#9C8BFE"/><stop offset="1" stop-color="#2B50FF"/></linearGradient><linearGradient x1="366.983" y1="280.208" x2="358.966" y2="161.282" gradientUnits="userSpaceOnUse" spreadMethod="pad" id="fill2"><stop offset="0" stop-color="#E3AFFE"/><stop offset="1" stop-color="#2B50FF"/></linearGradient><linearGradient x1="339.833" y1="251.78" x2="336.655" y2="198.744" gradientUnits="userSpaceOnUse" spreadMethod="pad" id="fill3"><stop offset="0" stop-color="#748DFA"/><stop offset="1" stop-color="#C1B8FF"/></linearGradient><linearGradient x1="366.61" y1="199.406" x2="331.082" y2="291.3" gradientUnits="userSpaceOnUse" spreadMethod="pad" id="fill4"><stop offset="0" stop-color="#DBABFE"/><stop offset="1" stop-color="#C8F2FF"/></linearGradient><linearGradient x1="369.17" y1="198.557" x2="335.983" y2="245.993" gradientUnits="userSpaceOnUse" spreadMethod="pad" id="stroke5"><stop offset="0" stop-color="#FFFFFF"/><stop offset="0.46875" stop-color="#FFFFFF" stop-opacity="0"/><stop offset="1" stop-color="#FFFFFF" stop-opacity="0"/></linearGradient><linearGradient x1="378.752" y1="221.569" x2="411.083" y2="175.73" gradientUnits="userSpaceOnUse" spreadMethod="pad" id="stroke6"><stop offset="0" stop-color="#FFFFFF"/><stop offset="0.46875" stop-color="#FFFFFF" stop-opacity="0"/><stop offset="1" stop-color="#FFFFFF" stop-opacity="0"/></linearGradient><linearGradient x1="405.519" y1="173.592" x2="409.26" y2="222.227" gradientUnits="userSpaceOnUse" spreadMethod="pad" id="fill7"><stop offset="0" stop-color="#DBABFE"/><stop offset="1" stop-color="#B1E8FA"/></linearGradient><linearGradient x1="356.715" y1="253.912" x2="350.448" y2="271.193" gradientUnits="userSpaceOnUse" spreadMethod="pad" id="stroke8"><stop offset="0" stop-color="#AA5FE6" stop-opacity="0"/><stop offset="1" stop-color="#2E75FE"/></linearGradient><linearGradient x1="350.864" y1="235.329" x2="339.765" y2="259.744" gradientUnits="userSpaceOnUse" spreadMethod="pad" id="stroke9"><stop offset="0" stop-color="#AA5FE6" stop-opacity="0"/><stop offset="1" stop-color="#2E75FE"/></linearGradient><linearGradient x1="353.774" y1="211.139" x2="340.952" y2="235.597" gradientUnits="userSpaceOnUse" spreadMethod="pad" id="stroke10"><stop offset="0" stop-color="#AA5FE6" stop-opacity="0"/><stop offset="1" stop-color="#2E75FE"/></linearGradient></defs><g clip-path="url(#clip0)" transform="translate(-290 -255)"><path d="M0 0 1280.24 0 1280.24 463.908 0 463.908Z" fill="none" transform="matrix(1 0 0 1.00081 -0.255482 128.069)"/><path d="M589.722 261.071 569.151 213.627C567.428 209.675 565.705 205.513 563.982 201.142 563.034 198.739 562.087 196.272 561.14 193.743L560.908 193.122 560.765 192.739 560.606 192.309 560.117 192.309 560.127 192.486 560.156 193.058 560.166 193.275C560.704 203.942 560.972 213.736 560.972 222.652L560.972 261.071 551.023 261.071 551.023 181.62 565.367 181.62 584.594 226.572C586.654 231.396 588.911 237.144 591.365 243.812L591.858 245.158 592.163 245.158 592.21 245.03 592.302 244.777 592.408 244.486C595.227 236.778 597.568 230.803 599.427 226.572L618.654 181.62 632.998 181.62 632.998 261.071 623.049 261.071 623.049 222.652C623.049 214.228 623.299 204.812 623.8 194.403L623.855 193.272 623.866 193.058 623.894 192.486 623.904 192.309 623.415 192.309 623.146 193.032 623.114 193.119C620.214 200.895 617.465 207.736 614.869 213.627L594.3 261.071 589.722 261.071ZM718.209 234.053C719.389 229.975 719.979 225.622 719.979 220.99 719.979 216.398 719.389 212.121 718.209 208.164 717.07 204.165 715.32 200.582 712.96 197.415 710.64 194.209 707.69 191.439 704.109 189.102 700.569 186.766 696.418 184.945 691.658 183.639 688.93 182.886 685.981 182.371 682.807 182.095 681.362 181.951 679.864 181.839 678.31 181.761 676.45 181.667 674.509 181.62 672.491 181.62L653.813 181.62 653.813 261.071 672.308 261.071C674.622 261.071 676.833 261.017 678.941 260.909L679.099 260.902 679.955 260.853C680.864 260.798 681.754 260.732 682.623 260.656 685.797 260.339 688.748 259.806 691.474 259.053 696.235 257.747 700.386 255.905 703.926 253.53 707.507 251.156 710.477 248.344 712.838 245.098 715.239 241.813 717.029 238.132 718.209 234.053ZM704.354 202.403C707.772 207.154 709.481 213.349 709.481 220.99 709.481 228.708 707.772 235.044 704.354 239.992 700.935 244.94 696.032 248.404 689.643 250.383 688.628 250.695 687.564 250.967 686.451 251.198 684.886 251.521 683.223 251.765 681.464 251.927 678.452 252.205 675.054 252.342 671.27 252.342L663.762 252.342 663.762 190.349 671.27 190.349C675.054 190.349 678.452 190.487 681.464 190.764 684.474 191.042 687.202 191.557 689.643 192.309 696.032 194.288 700.935 197.652 704.354 202.403ZM842.824 232.331C842.824 229.046 842.458 226.097 841.724 223.484 841.036 220.871 839.974 218.654 838.552 216.833 837.13 214.973 835.379 213.548 833.299 212.557 832.611 212.23 831.893 211.955 831.129 211.736 830.875 211.663 830.621 211.595 830.359 211.534 829.042 211.227 827.62 211.074 826.101 211.074 824.148 211.074 822.353 211.331 820.729 211.845 819.143 212.359 817.698 213.013 816.396 213.805 815.095 214.596 813.972 215.447 813.037 216.358 812.102 217.269 811.346 218.08 810.777 218.793L810.777 248.305C812.775 250.285 815.027 251.848 817.556 252.995 818.813 253.56 820.115 253.986 821.47 254.274 822.001 254.388 822.547 254.48 823.101 254.55 823.976 254.662 824.867 254.719 825.779 254.719L825.794 254.719C826.752 254.719 827.74 254.628 828.757 254.447 829.64 254.289 830.546 254.062 831.474 253.768 833.464 253.095 835.297 251.927 836.966 250.264 838.672 248.561 840.078 246.266 841.178 243.377 842.271 240.486 842.824 236.804 842.824 232.331ZM852.417 237.663C852.215 239.443 851.908 241.11 851.489 242.663 850.718 245.672 849.678 248.305 848.376 250.561 847.074 252.817 845.548 254.719 843.797 256.262 842.794 257.129 841.769 257.902 840.722 258.584 839.944 259.087 839.165 259.54 838.365 259.943 836.494 260.853 834.586 261.507 832.633 261.902 830.718 262.338 828.907 262.556 827.201 262.556 823.654 262.556 820.587 262.002 817.983 260.893 815.701 259.889 813.628 258.476 811.75 256.654 811.608 256.511 811.466 256.369 811.323 256.223 811.196 256.098 811.077 255.974 810.964 255.847L810.777 255.847 810.777 288.684 801.319 288.684 801.319 204.659 810.471 204.659 810.471 211.074 810.658 211.074C810.83 210.813 811.017 210.556 811.219 210.302 811.638 209.779 812.124 209.263 812.67 208.757 813.523 208.005 814.646 207.213 816.03 206.382 816.703 205.97 817.437 205.586 818.222 205.226 819.105 204.826 820.063 204.459 821.095 204.125 821.904 203.869 822.734 203.664 823.595 203.512 824.859 203.288 826.184 203.175 827.56 203.175 830.823 203.175 833.95 203.71 836.966 204.778 840.018 205.847 842.705 207.55 845.024 209.886 847.344 212.221 849.192 215.23 850.576 218.911 851.998 222.553 852.709 226.987 852.709 232.213 852.709 234.135 852.611 235.951 852.417 237.663ZM935.995 232.925C935.995 229.442 935.546 226.354 934.656 223.662 933.803 220.931 932.583 218.633 930.989 216.773 929.448 214.873 927.592 213.449 925.437 212.498 923.282 211.508 920.918 211.014 918.359 211.014 915.793 211.014 913.436 211.508 911.273 212.498 909.118 213.449 907.248 214.873 905.661 216.773 904.794 217.818 904.03 219.001 903.364 220.32 902.848 221.35 902.392 222.464 902.003 223.662 901.142 226.354 900.716 229.442 900.716 232.925 900.716 236.369 901.142 239.457 902.003 242.189 902.893 244.88 904.135 247.157 905.721 249.018 907.308 250.878 909.178 252.303 911.341 253.293 913.496 254.244 915.853 254.719 918.419 254.719 920.985 254.719 923.32 254.244 925.437 253.293 926.979 252.588 928.356 251.661 929.583 250.514 930.077 250.051 930.548 249.553 930.989 249.018 932.583 247.157 933.803 244.88 934.656 242.189 935.546 239.457 935.995 236.369 935.995 232.925ZM945.887 232.925C945.887 237.359 945.236 241.396 943.934 245.039 942.632 248.681 940.776 251.809 938.382 254.422 936.018 256.994 933.152 258.993 929.77 260.418 926.395 261.844 922.609 262.556 918.419 262.556 914.102 262.556 910.241 261.844 906.821 260.418 904.921 259.617 903.177 258.633 901.584 257.469 901.411 257.339 901.232 257.209 901.06 257.074 900.02 256.27 899.055 255.386 898.157 254.422 895.792 251.809 893.982 248.681 892.717 245.039 892.014 242.991 891.505 240.816 891.191 238.517L891.153 238.216C890.936 236.52 890.831 234.756 890.831 232.925 890.831 228.451 891.482 224.394 892.784 220.752 893.518 218.704 894.415 216.818 895.478 215.096L895.538 215.007C896.353 213.7 897.266 212.488 898.276 211.37 900.678 208.757 903.566 206.738 906.941 205.313 910.36 203.888 914.169 203.175 918.359 203.175 922.632 203.175 926.477 203.888 929.897 205.313 933.309 206.738 936.205 208.757 938.562 211.37 940.919 213.983 942.729 217.111 943.994 220.752 944.458 222.092 944.839 223.489 945.131 224.939L945.191 225.231C945.655 227.641 945.887 230.206 945.887 232.925ZM976.587 259.943 964.196 288.684 973.602 288.684 1009.79 204.659 999.842 204.659 981.593 248.899 981.226 248.899 963.224 204.659 953.519 204.659 976.587 259.943ZM787.896 235.787C785.352 249.711 773.073 260.247 758.79 260.247 742.478 260.247 729.206 246.852 729.206 230.387 729.206 220.419 734.071 211.575 741.531 206.149 746.423 202.538 752.452 200.404 758.962 200.404 771.248 200.404 782.067 207.869 786.527 219.398L788.24 223.81 788.046 223.884 788.068 223.934 742.402 241.925C742.482 242.04 742.561 242.152 742.643 242.264L742.821 242.506C743.147 242.943 743.49 243.366 743.849 243.774 747.55 247.979 752.961 250.636 758.962 250.636 763.361 250.636 767.477 249.173 770.836 246.695 775.326 243.309 778.416 238.091 778.91 232.132L779.067 232.146C779.067 232.138 779.067 232.132 779.067 232.126L779.074 232.008 788.442 232.811C788.367 233.768 788.24 234.712 788.068 235.642L787.896 235.787ZM747.047 213.806C742.027 217.517 738.759 223.504 738.759 230.246 738.759 230.51 738.765 230.775 738.776 231.039L738.782 231.207 738.787 231.304C738.801 231.567 738.82 231.83 738.843 232.09L738.86 232.267 738.901 232.641 738.949 233.037 775.356 218.692C771.637 213.294 765.531 209.998 758.775 209.998 754.405 210.001 750.357 211.413 747.047 213.806ZM535.763 252.342 535.763 261.071 485.955 261.071 485.955 181.62 495.904 181.62 495.904 252.342 535.763 252.342ZM865.743 175.088 875.201 175.088 875.201 261.071 865.743 261.071 865.743 175.088Z" fill="none" fill-rule="evenodd" transform="matrix(1 0 0 1.00081 -0.255482 128.069)"/><path d="M589.722 261.071 569.151 213.627C567.428 209.675 565.705 205.513 563.982 201.142 563.034 198.739 562.087 196.272 561.14 193.743L560.908 193.122 560.765 192.739 560.606 192.309 560.117 192.309 560.127 192.486 560.156 193.058 560.166 193.275C560.704 203.942 560.972 213.736 560.972 222.652L560.972 261.071 551.023 261.071 551.023 181.62 565.367 181.62 584.594 226.572C586.654 231.396 588.911 237.144 591.365 243.812L591.858 245.158 592.163 245.158 592.21 245.03 592.302 244.777 592.408 244.486C595.227 236.778 597.568 230.803 599.427 226.572L618.654 181.62 632.998 181.62 632.998 261.071 623.049 261.071 623.049 222.652C623.049 214.228 623.299 204.812 623.8 194.403L623.855 193.272 623.866 193.058 623.894 192.486 623.904 192.309 623.415 192.309 623.146 193.032 623.114 193.119C620.214 200.895 617.465 207.736 614.869 213.627L594.3 261.071 589.722 261.071ZM718.209 234.053C719.389 229.975 719.979 225.622 719.979 220.99 719.979 216.398 719.389 212.121 718.209 208.164 717.07 204.165 715.32 200.582 712.96 197.415 710.64 194.209 707.69 191.439 704.109 189.102 700.569 186.766 696.418 184.945 691.658 183.639 688.93 182.886 685.981 182.371 682.807 182.095 681.362 181.951 679.864 181.839 678.31 181.761 676.45 181.667 674.509 181.62 672.491 181.62L653.813 181.62 653.813 261.071 672.308 261.071C674.622 261.071 676.833 261.017 678.941 260.909L679.099 260.902 679.955 260.853C680.864 260.798 681.754 260.732 682.623 260.656 685.797 260.339 688.748 259.806 691.474 259.053 696.235 257.747 700.386 255.905 703.926 253.53 707.507 251.156 710.477 248.344 712.838 245.098 715.239 241.813 717.029 238.132 718.209 234.053ZM704.354 202.403C707.772 207.154 709.481 213.349 709.481 220.99 709.481 228.708 707.772 235.044 704.354 239.992 700.935 244.94 696.032 248.404 689.643 250.383 688.628 250.695 687.564 250.967 686.451 251.198 684.886 251.521 683.223 251.765 681.464 251.927 678.452 252.205 675.054 252.342 671.27 252.342L663.762 252.342 663.762 190.349 671.27 190.349C675.054 190.349 678.452 190.487 681.464 190.764 684.474 191.042 687.202 191.557 689.643 192.309 696.032 194.288 700.935 197.652 704.354 202.403ZM842.824 232.331C842.824 229.046 842.458 226.097 841.724 223.484 841.036 220.871 839.974 218.654 838.552 216.833 837.13 214.973 835.379 213.548 833.299 212.557 832.611 212.23 831.893 211.955 831.129 211.736 830.875 211.663 830.621 211.595 830.359 211.534 829.042 211.227 827.62 211.074 826.101 211.074 824.148 211.074 822.353 211.331 820.729 211.845 819.143 212.359 817.698 213.013 816.396 213.805 815.095 214.596 813.972 215.447 813.037 216.358 812.102 217.269 811.346 218.08 810.777 218.793L810.777 248.305C812.775 250.285 815.027 251.848 817.556 252.995 818.813 253.56 820.115 253.986 821.47 254.274 822.001 254.388 822.547 254.48 823.101 254.55 823.976 254.662 824.867 254.719 825.779 254.719L825.794 254.719C826.752 254.719 827.74 254.628 828.757 254.447 829.64 254.289 830.546 254.062 831.474 253.768 833.464 253.095 835.297 251.927 836.966 250.264 838.672 248.561 840.078 246.266 841.178 243.377 842.271 240.486 842.824 236.804 842.824 232.331ZM852.417 237.663C852.215 239.443 851.908 241.11 851.489 242.663 850.718 245.672 849.678 248.305 848.376 250.561 847.074 252.817 845.548 254.719 843.797 256.262 842.794 257.129 841.769 257.902 840.722 258.584 839.944 259.087 839.165 259.54 838.365 259.943 836.494 260.853 834.586 261.507 832.633 261.902 830.718 262.338 828.907 262.556 827.201 262.556 823.654 262.556 820.587 262.002 817.983 260.893 815.701 259.889 813.628 258.476 811.75 256.654 811.608 256.511 811.466 256.369 811.323 256.223 811.196 256.098 811.077 255.974 810.964 255.847L810.777 255.847 810.777 288.684 801.319 288.684 801.319 204.659 810.471 204.659 810.471 211.074 810.658 211.074C810.83 210.813 811.017 210.556 811.219 210.302 811.638 209.779 812.124 209.263 812.67 208.757 813.523 208.005 814.646 207.213 816.03 206.382 816.703 205.97 817.437 205.586 818.222 205.226 819.105 204.826 820.063 204.459 821.095 204.125 821.904 203.869 822.734 203.664 823.595 203.512 824.859 203.288 826.184 203.175 827.56 203.175 830.823 203.175 833.95 203.71 836.966 204.778 840.018 205.847 842.705 207.55 845.024 209.886 847.344 212.221 849.192 215.23 850.576 218.911 851.998 222.553 852.709 226.987 852.709 232.213 852.709 234.135 852.611 235.951 852.417 237.663ZM935.995 232.925C935.995 229.442 935.546 226.354 934.656 223.662 933.803 220.931 932.583 218.633 930.989 216.773 929.448 214.873 927.592 213.449 925.437 212.498 923.282 211.508 920.918 211.014 918.359 211.014 915.793 211.014 913.436 211.508 911.273 212.498 909.118 213.449 907.248 214.873 905.661 216.773 904.794 217.818 904.03 219.001 903.364 220.32 902.848 221.35 902.392 222.464 902.003 223.662 901.142 226.354 900.716 229.442 900.716 232.925 900.716 236.369 901.142 239.457 902.003 242.189 902.893 244.88 904.135 247.157 905.721 249.018 907.308 250.878 909.178 252.303 911.341 253.293 913.496 254.244 915.853 254.719 918.419 254.719 920.985 254.719 923.32 254.244 925.437 253.293 926.979 252.588 928.356 251.661 929.583 250.514 930.077 250.051 930.548 249.553 930.989 249.018 932.583 247.157 933.803 244.88 934.656 242.189 935.546 239.457 935.995 236.369 935.995 232.925ZM945.887 232.925C945.887 237.359 945.236 241.396 943.934 245.039 942.632 248.681 940.776 251.809 938.382 254.422 936.018 256.994 933.152 258.993 929.77 260.418 926.395 261.844 922.609 262.556 918.419 262.556 914.102 262.556 910.241 261.844 906.821 260.418 904.921 259.617 903.177 258.633 901.584 257.469 901.411 257.339 901.232 257.209 901.06 257.074 900.02 256.27 899.055 255.386 898.157 254.422 895.792 251.809 893.982 248.681 892.717 245.039 892.014 242.991 891.505 240.816 891.191 238.517L891.153 238.216C890.936 236.52 890.831 234.756 890.831 232.925 890.831 228.451 891.482 224.394 892.784 220.752 893.518 218.704 894.415 216.818 895.478 215.096L895.538 215.007C896.353 213.7 897.266 212.488 898.276 211.37 900.678 208.757 903.566 206.738 906.941 205.313 910.36 203.888 914.169 203.175 918.359 203.175 922.632 203.175 926.477 203.888 929.897 205.313 933.309 206.738 936.205 208.757 938.562 211.37 940.919 213.983 942.729 217.111 943.994 220.752 944.458 222.092 944.839 223.489 945.131 224.939L945.191 225.231C945.655 227.641 945.887 230.206 945.887 232.925ZM976.587 259.943 964.196 288.684 973.602 288.684 1009.79 204.659 999.842 204.659 981.593 248.899 981.226 248.899 963.224 204.659 953.519 204.659 976.587 259.943ZM787.896 235.787C785.352 249.711 773.073 260.247 758.79 260.247 742.478 260.247 729.206 246.852 729.206 230.387 729.206 220.419 734.071 211.575 741.531 206.149 746.423 202.538 752.452 200.404 758.962 200.404 771.248 200.404 782.067 207.869 786.527 219.398L788.24 223.81 788.046 223.884 788.068 223.934 742.402 241.925C742.482 242.04 742.561 242.152 742.643 242.264L742.821 242.506C743.147 242.943 743.49 243.366 743.849 243.774 747.55 247.979 752.961 250.636 758.962 250.636 763.361 250.636 767.477 249.173 770.836 246.695 775.326 243.309 778.416 238.091 778.91 232.132L779.067 232.146C779.067 232.138 779.067 232.132 779.067 232.126L779.074 232.008 788.442 232.811C788.367 233.768 788.24 234.712 788.068 235.642L787.896 235.787ZM747.047 213.806C742.027 217.517 738.759 223.504 738.759 230.246 738.759 230.51 738.765 230.775 738.776 231.039L738.782 231.207 738.787 231.304C738.801 231.567 738.82 231.83 738.843 232.09L738.86 232.267 738.901 232.641 738.949 233.037 775.356 218.692C771.637 213.294 765.531 209.998 758.775 209.998 754.405 210.001 750.357 211.413 747.047 213.806ZM535.763 252.342 535.763 261.071 485.955 261.071 485.955 181.62 495.904 181.62 495.904 252.342 535.763 252.342ZM865.743 175.088 875.201 175.088 875.201 261.071 865.743 261.071 865.743 175.088Z" fill="url(#fill1)" fill-rule="evenodd" transform="matrix(1 0 0 1.00081 -0.255482 128.069)"/><path d="M417.928 210.759 332.03 292.638 356.588 212.584 329.253 211.565 415.752 129.412 390.657 209.626 417.928 210.759Z" fill="url(#fill2)" fill-rule="evenodd" transform="matrix(1 0 0 1.00081 -0.255482 128.069)"/><path d="M352.974 215.897 331.46 292.898 370.665 200.078C370.913 199.492 370.362 198.884 369.754 199.072L328.536 211.86 352.35 214.954C352.802 215.013 353.097 215.459 352.974 215.897Z" fill="url(#fill3)" transform="matrix(1 0 0 1.00081 -0.255482 128.069)"/><path d="M352.974 215.897 331.46 292.898 370.665 200.078C370.913 199.492 370.362 198.884 369.754 199.072L328.536 211.86 352.35 214.954C352.802 215.013 353.097 215.459 352.974 215.897Z" fill="url(#fill4)" transform="matrix(1 0 0 1.00081 -0.255482 128.069)"/><path d="M352.974 215.897 331.46 292.898 370.665 200.078C370.913 199.492 370.362 198.884 369.754 199.072L328.536 211.86 352.35 214.954C352.802 215.013 353.097 215.459 352.974 215.897Z" stroke="url(#stroke5)" stroke-width="0.748239" fill="none" transform="matrix(1 0 0 1.00081 -0.255482 128.069)"/><path d="M394.247 202.173 415.328 129.974 377.297 220.145C377.057 220.715 377.573 221.314 378.172 221.161L417.509 211.1 394.716 203.089C394.342 202.957 394.135 202.554 394.247 202.173Z" stroke="url(#stroke6)" stroke-width="0.748239" fill="url(#fill7)" transform="matrix(1 0 0 1.00081 -0.255482 128.069)"/><path d="M400.69 240.126C415.788 244.356 425.536 251.018 425.453 258.426 425.315 270.82 397.71 280.608 363.797 280.288 329.883 279.969 302.503 269.662 302.641 257.268 302.735 248.864 315.458 241.657 334.215 237.989" stroke="url(#stroke8)" stroke-width="5.23768" fill="none" transform="matrix(1 0 0 1.00081 -0.255482 128.069)"/><path d="M403.693 233.437C417.578 241.42 425.394 250.68 423.145 258.396 419.383 271.306 388.87 275.007 354.995 266.662 321.119 258.317 296.707 241.086 300.47 228.176 303.021 219.421 317.873 214.902 337.734 215.501" stroke="url(#stroke9)" stroke-width="5.23768" fill="none" transform="matrix(1 0 0 1.00081 -0.255482 128.069)"/><path d="M403.498 232.586C414.855 243.273 420.115 253.555 416.138 259.89 409.483 270.487 379.485 266.019 349.137 249.91 318.787 233.801 299.58 212.151 306.236 201.553 310.748 194.367 325.995 194.108 344.807 199.71" stroke="url(#stroke10)" stroke-width="5.23768" fill="none" transform="matrix(1 0 0 1.00081 -0.255482 128.069)"/></g></svg>
...@@ -84,6 +84,32 @@ def check_ext_modules(): ...@@ -84,6 +84,32 @@ def check_ext_modules():
return False return False
def get_cuda_pkgs():
arg_name = '--cuda='
arg_value = None
for arg in sys.argv[1:]:
if arg.startswith(arg_name):
arg_value = arg[len(arg_name):]
sys.argv.remove(arg)
break
cuda_pkgs = []
if arg_value == '11':
cuda_pkgs = [
'nvidia-nccl-cu11', 'nvidia-cuda-runtime-cu11',
'nvidia-cublas-cu11'
]
elif arg_value == '12':
cuda_pkgs = [
'nvidia-nccl-cu12', 'nvidia-cuda-runtime-cu12',
'nvidia-cublas-cu12'
]
return cuda_pkgs
cuda_pkgs = get_cuda_pkgs()
def parse_requirements(fname='requirements.txt', with_version=True): def parse_requirements(fname='requirements.txt', with_version=True):
"""Parse the package dependencies listed in a file but strips specific """Parse the package dependencies listed in a file but strips specific
versioning information. versioning information.
...@@ -100,21 +126,6 @@ def parse_requirements(fname='requirements.txt', with_version=True): ...@@ -100,21 +126,6 @@ def parse_requirements(fname='requirements.txt', with_version=True):
""" """
require_fpath = fname require_fpath = fname
def get_nccl_pkg():
arg_name = '--cuda='
arg_value = None
for arg in sys.argv[1:]:
if arg.startswith(arg_name):
arg_value = arg[len(arg_name):]
sys.argv.remove(arg)
break
if arg_value == '11':
return 'nvidia-nccl-cu11'
elif arg_value == '12':
return 'nvidia-nccl-cu12'
return None
def parse_line(line): def parse_line(line):
"""Parse information from a line in a requirements text file.""" """Parse information from a line in a requirements text file."""
if line.startswith('-r '): if line.startswith('-r '):
...@@ -171,9 +182,7 @@ def parse_requirements(fname='requirements.txt', with_version=True): ...@@ -171,9 +182,7 @@ def parse_requirements(fname='requirements.txt', with_version=True):
yield item yield item
packages = list(gen_packages_items()) packages = list(gen_packages_items())
nccl_pkg = get_nccl_pkg() packages += cuda_pkgs
if nccl_pkg is not None:
packages += [nccl_pkg]
return packages return packages
......
...@@ -196,7 +196,7 @@ __global__ void generic_activation(T* out, ...@@ -196,7 +196,7 @@ __global__ void generic_activation(T* out,
using Float_T = typename packed_as<float, packed_elems>::type; using Float_T = typename packed_as<float, packed_elems>::type;
using Packed_Int8_t = typename packed_as<int8_t, packed_elems>::type; using Packed_Int8_t = typename packed_as<int8_t, packed_elems>::type;
for (int id = blockIdx.x * blockDim.x + threadIdx.x; id < m * n; id += blockDim.x * gridDim.x) { for (int64_t id = blockIdx.x * blockDim.x + threadIdx.x; id < 1LL * m * n; id += blockDim.x * gridDim.x) {
T val; T val;
if (int8_mode == 2) { if (int8_mode == 2) {
// val = cuda_cast<T>(cuda_cast<Float_T>(reinterpret_cast<Packed_Int8_t*>(out)[id]) * activation_in[0]); // val = cuda_cast<T>(cuda_cast<Float_T>(reinterpret_cast<Packed_Int8_t*>(out)[id]) * activation_in[0]);
...@@ -277,7 +277,7 @@ void invokeGenericActivation(T* out, ...@@ -277,7 +277,7 @@ void invokeGenericActivation(T* out,
} }
else { else {
block.x = n_threads; block.x = n_threads;
grid.x = ceil(m * n / double(n_threads)); grid.x = ceil(1LL * m * n / double(n_threads));
} }
TM_LOG_DEBUG("%d %d", grid.x, block.x); TM_LOG_DEBUG("%d %d", grid.x, block.x);
sync_check_cuda_error(); sync_check_cuda_error();
......
...@@ -367,6 +367,7 @@ template void invokeApplyRepetitionPenalty(half* logits, ...@@ -367,6 +367,7 @@ template void invokeApplyRepetitionPenalty(half* logits,
template<typename T, RepetitionPenaltyType penalty_type> template<typename T, RepetitionPenaltyType penalty_type>
__global__ void batchApplyRepetitionPenalty(T* logits, __global__ void batchApplyRepetitionPenalty(T* logits,
const float* penalties, const float* penalties,
int* penalty_workspace,
const int* output_ids, const int* output_ids,
const int batch_size, const int batch_size,
const int vocab_size, const int vocab_size,
...@@ -374,11 +375,12 @@ __global__ void batchApplyRepetitionPenalty(T* logits, ...@@ -374,11 +375,12 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
const int max_input_length, const int max_input_length,
const int step) const int step)
{ {
extern __shared__ float penalty_logits[];
int* penalty_indices = (int*)(penalty_logits + step);
const int batch_idx = blockIdx.x; const int batch_idx = blockIdx.x;
const float penalty = penalties[batch_idx]; const float penalty = penalties[batch_idx];
const int input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length; const int input_length = input_lengths != nullptr ? input_lengths[batch_idx] : max_input_length;
penalty_workspace += batch_idx * step * 2;
float* penalty_logits = (float*)penalty_workspace;
int* penalty_indices = (int*)(penalty_workspace + step);
logits += batch_idx * vocab_size; logits += batch_idx * vocab_size;
...@@ -409,10 +411,6 @@ __global__ void batchApplyRepetitionPenalty(T* logits, ...@@ -409,10 +411,6 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
} }
} }
if (blockDim.x > 32) {
__syncthreads();
}
// Phase 2. Replace a logit value by the penalized one. // Phase 2. Replace a logit value by the penalized one.
for (int index = threadIdx.x; index < step; index += blockDim.x) { for (int index = threadIdx.x; index < step; index += blockDim.x) {
// Skip the padding tokens in input sequences. // Skip the padding tokens in input sequences.
...@@ -426,6 +424,7 @@ __global__ void batchApplyRepetitionPenalty(T* logits, ...@@ -426,6 +424,7 @@ __global__ void batchApplyRepetitionPenalty(T* logits,
template<typename T> template<typename T>
void invokeBatchApplyRepetitionPenalty(T* logits, void invokeBatchApplyRepetitionPenalty(T* logits,
const float* penalties, const float* penalties,
int* penalty_workspace,
const int* output_ids, const int* output_ids,
const int batch_size, const int batch_size,
const int local_batch_size, const int local_batch_size,
...@@ -442,22 +441,30 @@ void invokeBatchApplyRepetitionPenalty(T* logits, ...@@ -442,22 +441,30 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
// output_ids [step, batch_size] : output token ids (with offset ite * local_batch_size). // output_ids [step, batch_size] : output token ids (with offset ite * local_batch_size).
// input_lengths [local_batch_size], input lengths (optional). // input_lengths [local_batch_size], input lengths (optional).
// Padding tokens at [input_length, max_input_length) of input will not be penalized. // Padding tokens at [input_length, max_input_length) of input will not be penalized.
dim3 block(min(step, 1024)); dim3 block(min(step, 1024));
dim3 grid(local_batch_size); dim3 grid(local_batch_size);
size_t smem_size = step * (sizeof(float) + sizeof(int));
if (penalty_type == RepetitionPenaltyType::Additive) { if (penalty_type == RepetitionPenaltyType::Additive) {
check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive>, batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, 0, stream>>>(logits,
cudaFuncAttributeMaxDynamicSharedMemorySize, penalties,
smem_size)); penalty_workspace,
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Additive><<<grid, block, smem_size, stream>>>( output_ids,
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); batch_size,
vocab_size,
input_lengths,
max_input_length,
step);
} }
else if (penalty_type == RepetitionPenaltyType::Multiplicative) { else if (penalty_type == RepetitionPenaltyType::Multiplicative) {
check_cuda_error(cudaFuncSetAttribute(batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>, batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative>
cudaFuncAttributeMaxDynamicSharedMemorySize, <<<grid, block, 0, stream>>>(logits,
smem_size)); penalties,
batchApplyRepetitionPenalty<T, RepetitionPenaltyType::Multiplicative><<<grid, block, smem_size, stream>>>( penalty_workspace,
logits, penalties, output_ids, batch_size, vocab_size, input_lengths, max_input_length, step); output_ids,
batch_size,
vocab_size,
input_lengths,
max_input_length,
step);
} }
else if (penalty_type == RepetitionPenaltyType::None) { else if (penalty_type == RepetitionPenaltyType::None) {
// do nothing // do nothing
...@@ -466,6 +473,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits, ...@@ -466,6 +473,7 @@ void invokeBatchApplyRepetitionPenalty(T* logits,
template void invokeBatchApplyRepetitionPenalty(float* logits, template void invokeBatchApplyRepetitionPenalty(float* logits,
const float* penalties, const float* penalties,
int* penalty_workspace,
const int* output_ids, const int* output_ids,
const int batch_size, const int batch_size,
const int local_batch_size, const int local_batch_size,
...@@ -478,6 +486,7 @@ template void invokeBatchApplyRepetitionPenalty(float* logits, ...@@ -478,6 +486,7 @@ template void invokeBatchApplyRepetitionPenalty(float* logits,
template void invokeBatchApplyRepetitionPenalty(half* logits, template void invokeBatchApplyRepetitionPenalty(half* logits,
const float* penalties, const float* penalties,
int* penalty_workspace,
const int* output_ids, const int* output_ids,
const int batch_size, const int batch_size,
const int local_batch_size, const int local_batch_size,
...@@ -497,9 +506,8 @@ __global__ void batchApplyMinLengthPenalty(T* logits, ...@@ -497,9 +506,8 @@ __global__ void batchApplyMinLengthPenalty(T* logits,
const int vocab_size_padded) const int vocab_size_padded)
{ {
int bid = threadIdx.x + blockIdx.x * blockDim.x; // batch index int bid = threadIdx.x + blockIdx.x * blockDim.x; // batch index
// We need +1 because sequence_lengths = max_input_length + num_gen_tokens - 1, // In decoder, sequence_lengths means length of sequence that has kv cache already computed
// which is equal to the length of k/v caches. if (sequence_lengths[bid] + 1 < min_lengths[bid]) {
if (sequence_lengths[bid] + 1 - max_input_length < min_lengths[bid]) {
T mask_val = (std::is_same<T, half>::value) ? -65504.0f : -FLT_MAX; T mask_val = (std::is_same<T, half>::value) ? -65504.0f : -FLT_MAX;
logits[bid * vocab_size_padded + end_ids[bid]] = mask_val; logits[bid * vocab_size_padded + end_ids[bid]] = mask_val;
} }
......
...@@ -40,6 +40,7 @@ void invokeApplyRepetitionPenalty(T* logits, ...@@ -40,6 +40,7 @@ void invokeApplyRepetitionPenalty(T* logits,
template<typename T> template<typename T>
void invokeBatchApplyRepetitionPenalty(T* logits, void invokeBatchApplyRepetitionPenalty(T* logits,
const float* penalties, const float* penalties,
int* penalty_workspace,
const int* output_ids, const int* output_ids,
const int batch_size, const int batch_size,
const int local_batch_size, const int local_batch_size,
......
...@@ -45,6 +45,7 @@ void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tenso ...@@ -45,6 +45,7 @@ void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tenso
repetition_penalty_ = (float*)std::realloc((void*)repetition_penalty_, batch_size * sizeof(float)); repetition_penalty_ = (float*)std::realloc((void*)repetition_penalty_, batch_size * sizeof(float));
min_lengths_ = (int*)std::realloc((void*)min_lengths_, batch_size * sizeof(int)); min_lengths_ = (int*)std::realloc((void*)min_lengths_, batch_size * sizeof(int));
skip_decode_ = (bool*)std::realloc((void*)skip_decode_, batch_size * sizeof(bool)); skip_decode_ = (bool*)std::realloc((void*)skip_decode_, batch_size * sizeof(bool));
context_length_ = (int*)std::realloc((void*)context_length_, batch_size * sizeof(int));
is_allocate_buffer_ = true; is_allocate_buffer_ = true;
} }
...@@ -54,6 +55,7 @@ void BaseSamplingLayer<T>::freeBuffer() ...@@ -54,6 +55,7 @@ void BaseSamplingLayer<T>::freeBuffer()
{ {
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (is_allocate_buffer_) { if (is_allocate_buffer_) {
allocator_->free((void**)(&repetition_penalty_workspace_));
allocator_->free((void**)(&temperature_buf_)); allocator_->free((void**)(&temperature_buf_));
allocator_->free((void**)(&repetition_penalty_buf_)); allocator_->free((void**)(&repetition_penalty_buf_));
allocator_->free((void**)(&min_lengths_buf_)); allocator_->free((void**)(&min_lengths_buf_));
...@@ -63,6 +65,7 @@ void BaseSamplingLayer<T>::freeBuffer() ...@@ -63,6 +65,7 @@ void BaseSamplingLayer<T>::freeBuffer()
std::free(repetition_penalty_); std::free(repetition_penalty_);
std::free(min_lengths_); std::free(min_lengths_);
std::free(skip_decode_); std::free(skip_decode_);
std::free(context_length_);
is_allocate_buffer_ = false; is_allocate_buffer_ = false;
} }
} }
...@@ -161,16 +164,23 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt ...@@ -161,16 +164,23 @@ void BaseSamplingLayer<T>::setup(const size_t batch_size, const size_t beam_widt
repetition_penalty_type_ = RepetitionPenaltyType::None; repetition_penalty_type_ = RepetitionPenaltyType::None;
} }
const int default_min_length = 0; // min_length
Tensor min_lengths = runtime_args->at("min_length", Tensor(MEMORY_CPU, TYPE_INT32, {1}, &default_min_length)); if (runtime_args->isExist("min_length")) {
if (min_lengths.size() == 1) { Tensor min_lengths = runtime_args->at("min_length");
int minlen = min_lengths.getVal<int>(); Tensor context_lengths = runtime_args->at("context_length");
deviceFill(min_lengths_buf_, batch_size, minlen, stream_); Tensor prompt_lengths = runtime_args->at("prompt_length");
std::fill_n(min_lengths_, batch_size, minlen); auto p1 = min_lengths.getPtr<int>();
auto p2 = prompt_lengths.getPtr<int>();
for (int i = 0; i < batch_size; i++) {
min_lengths_[i] = p1[i] + p2[i];
}
cudaAutoCpy(min_lengths_buf_, min_lengths_, batch_size, stream_);
std::copy_n(context_lengths.getPtr<int>(), batch_size, context_length_);
} }
else { else {
cudaAutoCpy(min_lengths_buf_, min_lengths.getPtr<int>(), batch_size, stream_); std::fill_n(min_lengths_, batch_size, 0);
std::copy_n(min_lengths.getPtr<int>(), batch_size, min_lengths_); deviceFill(min_lengths_buf_, batch_size, 0, stream_);
std::fill_n(context_length_, batch_size, 0);
} }
} }
...@@ -284,9 +294,12 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_t ...@@ -284,9 +294,12 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_t
if (step > 1 && repetition_penalty_type_ != RepetitionPenaltyType::None) { if (step > 1 && repetition_penalty_type_ != RepetitionPenaltyType::None) {
float default_value = getDefaultPenaltyValue(repetition_penalty_type_); float default_value = getDefaultPenaltyValue(repetition_penalty_type_);
if (!ALL_OF(repetition_penalty_ + ite * local_batch_size, local_batch_size, float, default_value)) { if (!ALL_OF(repetition_penalty_ + ite * local_batch_size, local_batch_size, float, default_value)) {
repetition_penalty_workspace_ = reinterpret_cast<int*>(allocator_->reMalloc(
repetition_penalty_workspace_, batch_size * step * (sizeof(int) + sizeof(float)), false));
invokeBatchApplyRepetitionPenalty( invokeBatchApplyRepetitionPenalty(
logits, logits,
repetition_penalty_buf_ + ite * local_batch_size, repetition_penalty_buf_ + ite * local_batch_size,
repetition_penalty_workspace_ + ite * local_batch_size,
output_tensors->at("output_ids").getPtrWithOffset<int>(ite * local_batch_size), output_tensors->at("output_ids").getPtrWithOffset<int>(ite * local_batch_size),
batch_size, batch_size,
local_batch_size, local_batch_size,
...@@ -300,10 +313,12 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_t ...@@ -300,10 +313,12 @@ void BaseSamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_t
} }
} }
const int num_generated_tokens = step - max_input_length; const int num_generated_tokens = step - max_input_length;
const int* min_lengths = min_lengths_ + ite * local_batch_size; const int* min_lengths = min_lengths_ + ite * local_batch_size;
std::vector<int> index(local_batch_size);
std::iota(index.begin(), index.end(), 0);
const bool invoke_min_length_penalty = std::any_of( const bool invoke_min_length_penalty = std::any_of(
min_lengths, min_lengths + local_batch_size, [&](int min_length) { return min_length > num_generated_tokens; }); index.begin(), index.end(), [&](int i) { return min_lengths[i] > context_length_[i] + num_generated_tokens; });
if (invoke_min_length_penalty) { if (invoke_min_length_penalty) {
FT_CHECK_WITH_INFO(input_tensors->isExist("end_id"), "Need end_id to apply min length penlaty"); FT_CHECK_WITH_INFO(input_tensors->isExist("end_id"), "Need end_id to apply min length penlaty");
invokeMinLengthPenalty(logits, invokeMinLengthPenalty(logits,
......
...@@ -33,6 +33,8 @@ protected: ...@@ -33,6 +33,8 @@ protected:
size_t vocab_size_; size_t vocab_size_;
size_t vocab_size_padded_; size_t vocab_size_padded_;
int* repetition_penalty_workspace_;
size_t sampling_workspace_size_; size_t sampling_workspace_size_;
void* sampling_workspace_ = nullptr; void* sampling_workspace_ = nullptr;
...@@ -47,6 +49,7 @@ protected: ...@@ -47,6 +49,7 @@ protected:
int* min_lengths_ = nullptr; int* min_lengths_ = nullptr;
bool* skip_decode_ = nullptr; bool* skip_decode_ = nullptr;
bool skip_any_ = false; bool skip_any_ = false;
int* context_length_ = nullptr;
RepetitionPenaltyType repetition_penalty_type_ = RepetitionPenaltyType::None; RepetitionPenaltyType repetition_penalty_type_ = RepetitionPenaltyType::None;
......
...@@ -11,11 +11,28 @@ ...@@ -11,11 +11,28 @@
namespace turbomind { namespace turbomind {
BlockManager::BlockManager(size_t block_size, double block_count, int chunk_size, IAllocator* allocator): size_t GetSyncFreeMemSize(Barrier& barrier, std::atomic<size_t>& value)
{
size_t free{};
size_t total{};
check_cuda_error(cudaMemGetInfo(&free, &total));
// atomicMin
auto old = value.load();
while (old > free && !value.compare_exchange_weak(old, free)) {}
// wait for all ranks
barrier.wait();
return value.load();
}
BlockManager::BlockManager(
size_t block_size, double block_count, int chunk_size, IAllocator* allocator, GetFreeMemSize get_free_size):
block_size_(block_size), allocator_(allocator) block_size_(block_size), allocator_(allocator)
{ {
if (block_count < 1.) { if (block_count < 1.) {
max_block_count_ = GetBlockCount(block_size, block_count); max_block_count_ = GetBlockCount(block_size, block_count, get_free_size);
} }
else { else {
max_block_count_ = block_count; max_block_count_ = block_count;
...@@ -81,12 +98,10 @@ bool BlockManager::Malloc() ...@@ -81,12 +98,10 @@ bool BlockManager::Malloc()
return true; return true;
} }
size_t BlockManager::GetBlockCount(size_t block_size, double ratio) size_t BlockManager::GetBlockCount(size_t block_size, double ratio, GetFreeMemSize get_free_size)
{ {
size_t free{}; size_t free = get_free_size();
size_t total{}; return static_cast<size_t>(free * ratio) / block_size;
check_cuda_error(cudaMemGetInfo(&free, &total));
return static_cast<size_t>(total * ratio) / block_size;
} }
void BlockManager::Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst) void BlockManager::Move(std::vector<int>& src, const std::vector<int>& delta, std::vector<int>& dst)
......
...@@ -2,12 +2,15 @@ ...@@ -2,12 +2,15 @@
#pragma once #pragma once
#include "src/turbomind/models/llama/Barrier.h"
#include "src/turbomind/utils/allocator.h" #include "src/turbomind/utils/allocator.h"
#include "src/turbomind/utils/cuda_utils.h" #include "src/turbomind/utils/cuda_utils.h"
#include "src/turbomind/utils/logger.h" #include "src/turbomind/utils/logger.h"
#include <algorithm> #include <algorithm>
#include <atomic>
#include <cstdint> #include <cstdint>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <functional>
#include <iterator> #include <iterator>
#include <numeric> #include <numeric>
#include <queue> #include <queue>
...@@ -63,9 +66,14 @@ struct Snapshot { ...@@ -63,9 +66,14 @@ struct Snapshot {
std::vector<int> use_count; std::vector<int> use_count;
}; };
using GetFreeMemSize = std::function<size_t()>;
size_t GetSyncFreeMemSize(Barrier& barrier, std::atomic<size_t>& value);
class BlockManager { class BlockManager {
public: public:
explicit BlockManager(size_t block_size, double block_count, int chunk_size, IAllocator* allocator); explicit BlockManager(
size_t block_size, double block_count, int chunk_size, IAllocator* allocator, GetFreeMemSize get_free_size);
~BlockManager(); ~BlockManager();
...@@ -124,7 +132,7 @@ public: ...@@ -124,7 +132,7 @@ public:
friend std::ostream& operator<<(std::ostream& os, const BlockManager&); friend std::ostream& operator<<(std::ostream& os, const BlockManager&);
private: private:
static size_t GetBlockCount(size_t block_size, double ratio); static size_t GetBlockCount(size_t block_size, double ratio, GetFreeMemSize get_free_size);
// move indices between sets // move indices between sets
static void Move(BlockIds& src, const BlockIds& delta, BlockIds& dst); static void Move(BlockIds& src, const BlockIds& delta, BlockIds& dst);
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include "src/turbomind/kernels/decoding_kernels.h" #include "src/turbomind/kernels/decoding_kernels.h"
#include "src/turbomind/kernels/sampling_topk_kernels.h" #include "src/turbomind/kernels/sampling_topk_kernels.h"
#include "src/turbomind/macro.h" #include "src/turbomind/macro.h"
#include "src/turbomind/models/llama/BlockManager.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h" #include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/LlamaV2.h" #include "src/turbomind/models/llama/LlamaV2.h"
#include "src/turbomind/models/llama/Request.h" #include "src/turbomind/models/llama/Request.h"
...@@ -328,6 +329,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests) ...@@ -328,6 +329,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
} }
// total context length (history + input) // total context length (history + input)
state.h_prompt_length[idx] = output_ids - output_ids_base;
state.h_context_length[idx] = output_ids - output_ids_base; state.h_context_length[idx] = output_ids - output_ids_base;
state.h_finished[idx] = false; state.h_finished[idx] = false;
...@@ -365,6 +367,9 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests) ...@@ -365,6 +367,9 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
// scaling_factor = std::max(exp2f(ceilf(log2f((float)max_seq_len / max_pos_emb) + 1.f)) // scaling_factor = std::max(exp2f(ceilf(log2f((float)max_seq_len / max_pos_emb) + 1.f))
// - 1.f, 1.f); // - 1.f, 1.f);
} }
else {
scaling_factor = 1.f;
}
} }
if (scaling_factor != 1.f) { if (scaling_factor != 1.f) {
float rope_dim = model_->attn_params_.rotary_embedding_dim; float rope_dim = model_->attn_params_.rotary_embedding_dim;
...@@ -695,6 +700,7 @@ void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchSta ...@@ -695,6 +700,7 @@ void LlamaBatch<T>::CopyState(const std::vector<std::tuple<BatchState*, BatchSta
} }
for (const auto& [s, d, si, di] : desc) { for (const auto& [s, d, si, di] : desc) {
d->h_prompt_length[di] = s->h_prompt_length[si];
d->h_context_length[di] = s->h_context_length[si]; d->h_context_length[di] = s->h_context_length[si];
d->h_finished[di] = s->h_finished[si]; d->h_finished[di] = s->h_finished[si];
d->h_rope_theta[di] = s->h_rope_theta[si]; d->h_rope_theta[di] = s->h_rope_theta[si];
...@@ -769,6 +775,7 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size) ...@@ -769,6 +775,7 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
h_bad_words_ = h_bad_words_ =
(int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true); (int*)allocator_->reMalloc(h_bad_words_, sizeof(int) * max_batch_size * kMaxStopBadWordsLen, true, true);
h_min_length_ = (int*)allocator_->reMalloc(h_min_length_, sizeof(int) * max_batch_size, true, true);
h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true); h_runtime_top_k_ = (int*)allocator_->reMalloc(h_runtime_top_k_, sizeof(int) * max_batch_size, true, true);
h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true); h_runtime_top_p_ = (float*)allocator_->reMalloc(h_runtime_top_p_, sizeof(float) * max_batch_size, true, true);
h_temperature_ = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true); h_temperature_ = (float*)allocator_->reMalloc(h_temperature_, sizeof(float) * max_batch_size, true, true);
...@@ -791,6 +798,7 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size) ...@@ -791,6 +798,7 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
sampling_params_ = { sampling_params_ = {
{"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_}, {"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
{"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_}, {"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_},
{"min_length", (std::byte*)h_min_length_, nullptr},
{"runtime_top_k", (std::byte*)h_runtime_top_k_, nullptr}, {"runtime_top_k", (std::byte*)h_runtime_top_k_, nullptr},
{"runtime_top_p", (std::byte*)h_runtime_top_p_, nullptr}, {"runtime_top_p", (std::byte*)h_runtime_top_p_, nullptr},
{"temperature", (std::byte*)h_temperature_, nullptr}, {"temperature", (std::byte*)h_temperature_, nullptr},
...@@ -823,6 +831,8 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size) ...@@ -823,6 +831,8 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size)
(uintptr_t*)allocator_->reMalloc(h_v_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true); (uintptr_t*)allocator_->reMalloc(h_v_block_ptrs_, sizeof(uintptr_t) * max_block_count, false, true);
for (auto& s : states_) { for (auto& s : states_) {
s.h_prompt_length =
(int*)allocator_->reMalloc(s.h_prompt_length, sizeof(int) * max_batch_size, false, true);
s.h_context_length = s.h_context_length =
(int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true); (int*)allocator_->reMalloc(s.h_context_length, sizeof(int) * max_batch_size, false, true);
s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true); s.h_finished = (bool*)allocator_->reMalloc(s.h_finished, sizeof(bool) * max_batch_size * 2, false, true);
...@@ -943,6 +953,10 @@ LlamaBatch<T>::LlamaBatch(const EngineParams& params, int cache_block_seq_len, i ...@@ -943,6 +953,10 @@ LlamaBatch<T>::LlamaBatch(const EngineParams& params, int cache_block_seq_len, i
const size_t elem_bits = (quant_policy & QuantPolicy::kCacheKVInt8) ? 8 : sizeof(T) * 8; const size_t elem_bits = (quant_policy & QuantPolicy::kCacheKVInt8) ? 8 : sizeof(T) * 8;
auto get_free_size = [&] {
return GetSyncFreeMemSize(*model_->shared_state_->barrier, model_->shared_state_->free_size);
};
sequence_manager_.reset(new SequenceManager{model_->num_layer_, sequence_manager_.reset(new SequenceManager{model_->num_layer_,
model_->local_kv_head_num_, model_->local_kv_head_num_,
model_->size_per_head_, model_->size_per_head_,
...@@ -951,7 +965,8 @@ LlamaBatch<T>::LlamaBatch(const EngineParams& params, int cache_block_seq_len, i ...@@ -951,7 +965,8 @@ LlamaBatch<T>::LlamaBatch(const EngineParams& params, int cache_block_seq_len, i
params.cache_chunk_size, params.cache_chunk_size,
elem_bits, elem_bits,
model->tensor_para_.rank_, model->tensor_para_.rank_,
allocator_}); allocator_,
get_free_size});
const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len; const size_t max_session_len = sequence_manager_->max_block_count() * cache_block_seq_len;
if (max_session_len < session_len_) { if (max_session_len < session_len_) {
...@@ -1054,6 +1069,12 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g) ...@@ -1054,6 +1069,12 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
} }
} }
// MinLengthPenalty
if (inputs.isExist("min_length")) {
inputs.insert({"prompt_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_prompt_length}});
inputs.insert({"context_length", {MEMORY_CPU, TYPE_INT32, {(size_t)batch_size}, state_->h_context_length}});
}
// init for eos // init for eos
std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_); std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_); Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_);
...@@ -1065,9 +1086,10 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g) ...@@ -1065,9 +1086,10 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
} }
template<typename T> template<typename T>
void LlamaBatch<T>::OutputContextLogits(T* context_decoder_output, void LlamaBatch<T>::OutputContextLogits(T* context_decoder_output,
const std::vector<int>& indices, const std::vector<int>& indices,
const std::vector<int>& lengths) const std::vector<int>& lengths,
const std::vector<const Sequence*>& sequences)
{ {
std::vector<float*> output_logits; std::vector<float*> output_logits;
int num_token = 0; int num_token = 0;
...@@ -1075,7 +1097,11 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_ ...@@ -1075,7 +1097,11 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_
bool is_return_logits = false; bool is_return_logits = false;
for (int k = 0; k < indices.size(); ++k) { for (int k = 0; k < indices.size(); ++k) {
auto& request = state_->requests[indices[k]]; auto& request = state_->requests[indices[k]];
output_logits.push_back(request->outputs[rank_].getPtr<float>("logits", nullptr)); auto logits = request->outputs[rank_].getPtr<float>("logits", nullptr);
if (logits && sequences[k]->cache_len + lengths[k] <= sequences[k]->tokens.size()) {
logits = nullptr;
}
output_logits.push_back(logits);
num_token += lengths[k]; num_token += lengths[k];
if (output_logits.back()) { if (output_logits.back()) {
is_return_logits = true; is_return_logits = true;
...@@ -1095,7 +1121,7 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_ ...@@ -1095,7 +1121,7 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_
FT_CHECK(model_->vocab_size_padded_ % tp == 0); FT_CHECK(model_->vocab_size_padded_ % tp == 0);
const auto local_vocab_size = model_->vocab_size_padded_ / tp; const auto local_vocab_size = model_->vocab_size_padded_ / tp;
local_context_logits_buf_ = local_context_logits_buf_ =
(float*)allocator_->malloc(sizeof(float) * local_vocab_size * max_context_token_num_); (float*)allocator_->malloc(sizeof(float) * model_->vocab_size_padded_ * max_context_token_num_);
} }
} }
...@@ -1105,7 +1131,27 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_ ...@@ -1105,7 +1131,27 @@ void LlamaBatch<T>::OutputContextLogits(T* context_decoder_
for (int k = 0; k < indices.size(); ++k) { for (int k = 0; k < indices.size(); ++k) {
if (output_logits[k]) { if (output_logits[k]) {
Copy(logits, model_->vocab_size_ * lengths[k], output_logits[k]); auto src_ptr = logits;
auto dst_ptr = output_logits[k];
int num_new_token = 0;
if (sequences[k]->cache_len < sequences[k]->tokens.size()) {
num_new_token = sequences[k]->cache_len + lengths[k] - sequences[k]->tokens.size();
src_ptr += (lengths[k] - num_new_token) * model_->vocab_size_padded_;
}
else {
num_new_token = lengths[k];
dst_ptr += (sequences[k]->cache_len - sequences[k]->tokens.size()) * model_->vocab_size_;
}
if (model_->vocab_size_padded_ == model_->vocab_size_) {
Copy(src_ptr, model_->vocab_size_ * num_new_token, dst_ptr);
}
else {
for (int tok = 0; tok < num_new_token; tok++) {
Copy(src_ptr, model_->vocab_size_, dst_ptr);
src_ptr += model_->vocab_size_padded_;
dst_ptr += model_->vocab_size_;
}
}
} }
logits += model_->vocab_size_padded_ * lengths[k]; logits += model_->vocab_size_padded_ * lengths[k];
} }
...@@ -1562,7 +1608,7 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter) ...@@ -1562,7 +1608,7 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
if (iter == 0) { if (iter == 0) {
// compute logits of inputs if requested // compute logits of inputs if requested
OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths); OutputContextLogits(context_decoder_output_buf_, decode_indices, decode_lengths, sequences);
} }
} }
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
namespace turbomind { namespace turbomind {
struct BatchState { struct BatchState {
int* h_prompt_length; // history + input, ignore generated
int* h_context_length; int* h_context_length;
bool* h_finished; bool* h_finished;
...@@ -92,8 +93,10 @@ public: ...@@ -92,8 +93,10 @@ public:
[[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false); [[nodiscard]] Signal Interrupt(int index, bool force_stop = false, bool force_end = false);
void void OutputContextLogits(T* context_decoder_output,
OutputContextLogits(T* context_decoder_output, const std::vector<int>& indices, const std::vector<int>& lengths); const std::vector<int>& indices,
const std::vector<int>& lengths,
const std::vector<const Sequence*>& sequences);
explicit LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2<T>* model); explicit LlamaBatch(const EngineParams& params, int cache_block_seq_len, int quant_policy, LlamaV2<T>* model);
...@@ -246,6 +249,7 @@ private: ...@@ -246,6 +249,7 @@ private:
uintptr_t* h_k_block_ptrs_{}; uintptr_t* h_k_block_ptrs_{};
uintptr_t* h_v_block_ptrs_{}; uintptr_t* h_v_block_ptrs_{};
int* h_min_length_{};
int* h_runtime_top_k_{}; int* h_runtime_top_k_{};
float* h_runtime_top_p_{}; float* h_runtime_top_p_{};
float* h_temperature_{}; float* h_temperature_{};
......
...@@ -177,11 +177,17 @@ void LlamaV2<T>::updateEmbedding(T* decoder_input, const int bsz, const int* h_i ...@@ -177,11 +177,17 @@ void LlamaV2<T>::updateEmbedding(T* decoder_input, const int bsz, const int* h_i
for (int j = embeddings.size() - 1; j >= 0; j--) { for (int j = embeddings.size() - 1; j >= 0; j--) {
int begin = ranges[j].first; int begin = ranges[j].first;
int end = ranges[j].second; int end = ranges[j].second;
if (seq.cache_len + h_input_length[i] - 1 < begin) {
continue;
}
if (end <= seq.cache_len) { if (end <= seq.cache_len) {
break; break;
} }
int off_dst = std::max(0, begin - seq.cache_len); int off_dst = std::max(0, begin - seq.cache_len);
int off_src = std::max(0, seq.cache_len - begin); int off_src = std::max(0, seq.cache_len - begin);
// calculate intersection of [begin, end) and [seq.cache_len, seq.cache_len + h_input_length[i])
begin = std::max(begin, seq.cache_len);
end = std::min(end, seq.cache_len + h_input_length[i]);
size_t byte_size = (end - begin) * hidden_units_ * sizeof(T); size_t byte_size = (end - begin) * hidden_units_ * sizeof(T);
T* dst_ptr = decoder_input + off_dst * hidden_units_; T* dst_ptr = decoder_input + off_dst * hidden_units_;
auto src_ptr = embeddings[j].data() + off_src * hidden_units_ * sizeof(T); auto src_ptr = embeddings[j].data() + off_src * hidden_units_ * sizeof(T);
......
...@@ -33,6 +33,7 @@ ...@@ -33,6 +33,7 @@
#include "src/turbomind/utils/cublasMMWrapper.h" #include "src/turbomind/utils/cublasMMWrapper.h"
#include "src/turbomind/utils/instance_comm.h" #include "src/turbomind/utils/instance_comm.h"
#include "src/turbomind/utils/nccl_utils.h" #include "src/turbomind/utils/nccl_utils.h"
#include <limits>
#include <unordered_map> #include <unordered_map>
using ffi_api_lock_ctrl_t = std::function<void(int)>; using ffi_api_lock_ctrl_t = std::function<void(int)>;
...@@ -48,6 +49,7 @@ public: ...@@ -48,6 +49,7 @@ public:
RequestQueue request_queue; RequestQueue request_queue;
std::shared_ptr<Barrier> barrier; std::shared_ptr<Barrier> barrier;
bool abort; bool abort;
std::atomic<size_t> free_size{std::numeric_limits<size_t>::max()};
}; };
~LlamaV2(); ~LlamaV2();
......
...@@ -13,23 +13,24 @@ ...@@ -13,23 +13,24 @@
namespace turbomind { namespace turbomind {
SequenceManager::SequenceManager(size_t layer_num, SequenceManager::SequenceManager(size_t layer_num,
size_t head_num, size_t head_num,
size_t head_dim, size_t head_dim,
size_t block_seq_len, size_t block_seq_len,
double block_count, double block_count,
int chunk_size, int chunk_size,
size_t elem_bits, size_t elem_bits,
int rank, int rank,
IAllocator* allocator): IAllocator* allocator,
block_seq_len_(block_seq_len) GetFreeMemSize get_free_size):
block_seq_len_(block_seq_len), rank_(rank)
{ {
constexpr int kBitsPerByte = 8; constexpr int kBitsPerByte = 8;
// [2, L, H, block_seq_len, D] // [2, L, H, block_seq_len, D]
size_t block_size = 2UL * layer_num * head_num * block_seq_len * head_dim * elem_bits / kBitsPerByte; size_t block_size = 2UL * layer_num * head_num * block_seq_len * head_dim * elem_bits / kBitsPerByte;
block_manager_ = std::make_unique<BlockManager>(block_size, block_count, chunk_size, allocator); block_manager_ = std::make_unique<BlockManager>(block_size, block_count, chunk_size, allocator, get_free_size);
val_offset_ = block_size / 2; val_offset_ = block_size / 2;
} }
......
...@@ -54,15 +54,16 @@ inline std::ostream& operator<<(std::ostream& os, const Sequence& seq) ...@@ -54,15 +54,16 @@ inline std::ostream& operator<<(std::ostream& os, const Sequence& seq)
class SequenceManager { class SequenceManager {
public: public:
explicit SequenceManager(size_t layer_num, explicit SequenceManager(size_t layer_num,
size_t head_num, size_t head_num,
size_t head_dim, size_t head_dim,
size_t block_seq_len, size_t block_seq_len,
double block_count, double block_count,
int chunk_size, int chunk_size,
size_t elem_bits, size_t elem_bits,
int rank, int rank,
IAllocator* allocator); IAllocator* allocator,
GetFreeMemSize get_free_size);
SequenceManager(const SequenceManager&) = delete; SequenceManager(const SequenceManager&) = delete;
SequenceManager(SequenceManager&&) noexcept = default; SequenceManager(SequenceManager&&) noexcept = default;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment